Refactor and simplify hook design & add Tensor.register_hook API #31775
Changes from all commits
@@ -384,8 +384,8 @@ static platform::Place GetPlaceOfVar(
 void GradientAccumulator::AccumulateGrad() {
   /**
    * If the gradient has been calculated by previous graph,
    * it should be added to the previous graph result.
    * If the leaf gradient has been calculated done, the inner_var_
    * should be added to the var_.
    */
   if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {

Review comment: The !HasInnerVar() check should be removable now.
Reply: I don't think so; every call to AccumulateGrad still requires an inner var.
Review comment: Right, agreed.

     return;
@@ -396,7 +396,7 @@ void GradientAccumulator::AccumulateGrad() {
                         "this auto-grad"));
   PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(), true,
                     platform::errors::InvalidArgument(
                         "Interior var of Leaf tensor should be initialized."));
   auto* src = inner_var_->MutableVar();
   auto* dst = var_->MutableVar();
   if (!var_->IsEmpty()) {
@@ -427,10 +427,65 @@ void GradientAccumulator::AccumulateGrad() {
     *(dst) = std::move(*src);
     var_->SetType(inner_var_->Type());
     var_->SetDataType(inner_var_->DataType());
     var_->SetIsEmpty(false);
   }
   inner_var_.reset();
 }

 void GradientAccumulator::CallGradientHooks() {
   PADDLE_ENFORCE_EQ(var_->IsLeafGrad(), true,
                     platform::errors::Unavailable(
                         "Only leaf gradient Tensor can deal with by gradient "
                         "hook in gradient accumulator."));
   PADDLE_ENFORCE_EQ(
       SumGradCompleted(), true,
       platform::errors::PreconditionNotMet(
           "Only can call gradient hooks after sum gradient completed."));
   PADDLE_ENFORCE_EQ(
       HasInnerVar(), true,
       platform::errors::PreconditionNotMet(
           "Leaf Tensor's inner var is nullptr when call gradient hook."));
   PADDLE_ENFORCE_EQ(
       inner_var_->Var().IsInitialized(), true,
       platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                            "is not initialized when "
                                            "call gradient hook."));
   if (var_->HasHook()) {

Review comment: Seal this, or make it different from the same code in …
Reply: Only the for loop is similar.

     VLOG(3) << "Call " << var_->GetHooks().size()
             << " hooks of leaf gradient accumulator's inner var `"
             << var_->Name() << "`.";
     auto tmp_var = inner_var_;
     VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
             << var_->GetHooks().size();
     for (const auto& hook_pair : var_->GetHooks()) {
       tmp_var = (*hook_pair.second)(tmp_var);
     }
     inner_var_ = tmp_var;

Review comment: For a leaf node, running CallGradientHooks inside GradientAccumulator replaces its own inner_var_, so this is effectively in-place, right?
Reply: Yes, it was already in-place. The main goal of this change is to unify hook management and invocation under a common base class; if we used an InplaceHook here, the old HookPipeLine machinery would still be needed, and both the data structures and the logic would be more complex.
Review comment: OK.

   }
 }
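As the thread above notes, the gradient hooks are applied functionally but the overall effect on the accumulator is in-place: each hook receives the current gradient value and returns the value handed to the next hook, and the final result replaces inner_var_. A minimal sketch of that chaining in plain Python (hypothetical names, not Paddle's actual classes), just to make the data flow explicit:

# Sketch only: mirrors the for-loop in CallGradientHooks above.
# `grad` stands in for inner_var_, `hooks` for var_->GetHooks().
def apply_gradient_hooks(grad, hooks):
    tmp = grad
    for hook in hooks:
        # each hook maps the current value to the one passed to the next hook
        tmp = hook(tmp)
    return tmp  # the caller assigns this back, replacing inner_var_

# Example: two hooks applied in registration order.
hooks = [lambda g: g * 2.0, lambda g: g + 1.0]
print(apply_gradient_hooks(3.0, hooks))  # (3.0 * 2.0) + 1.0 = 7.0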
 void GradientAccumulator::CallReduceHooks() {

Review comment: Add some checks to distinguish it from a normal hook.
Reply: Done.

   PADDLE_ENFORCE_EQ(
       var_->IsLeafGrad(), true,
       platform::errors::Unavailable("Only leaf gradient Tensor can deal with "
                                     "by reduce hook in gradient accumulator."));
   PADDLE_ENFORCE_EQ(SumGradCompleted(), true,
                     platform::errors::PreconditionNotMet(
                         "Only can call reduce hooks after the gradient "
                         "summation is completed in current batch."));
   PADDLE_ENFORCE_EQ(HasInnerVar(), false,
                     platform::errors::PreconditionNotMet(
                         "Only can call reduce hooks after the "
                         "gradient accumulation is completed in "
                         "current batch or across batchs."));
   if (var_->HasMutableHook()) {
     for (const auto& hook : var_->GetMutableHooks()) {
       VLOG(3) << "call gradient accumulator backward hooks.";
       (*hook)(var_);
     }
   }
 }
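For contrast with the gradient hooks above: the reduce hooks guarded by HasMutableHook() do not return a replacement value; they mutate the accumulated leaf variable in place, and the HasInnerVar() == false check means they only run once accumulation has finished. A rough plain-Python sketch of that calling discipline (names are illustrative, not Paddle's API):

# Sketch only: mirrors the preconditions and loop in CallReduceHooks.
def call_reduce_hooks(var, reduce_hooks, sum_completed, has_inner_var):
    assert sum_completed, "reduce hooks run only after gradient summation"
    assert not has_inner_var, "reduce hooks run only after accumulation is done"
    for hook in reduce_hooks:
        hook(var)  # mutates var in place; any return value is ignored

# Example: a reduce hook that scales the accumulated gradient in place.
grad = {"value": 4.0}
call_reduce_hooks(grad, [lambda v: v.update(value=v["value"] / 2)], True, False)
print(grad["value"])  # 2.0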
 void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
                                        size_t trace_id, bool unchange_input) {
   /**
Review comment: Bad name, or maybe use inheritance to fix it? CallHooks suggests invoking all the hooks, so having a separate CallReduceHooks is confusing to me.
Reply: Done. CallHooks -> CallGradientHooks.
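For context on the user-facing side of this PR, here is a minimal sketch of how the new Tensor.register_hook API is typically used from Python dygraph (the tensor values and the hook body are illustrative):

import paddle

x = paddle.to_tensor([0.0, 1.0, 2.0], stop_gradient=False)

# A gradient hook receives the gradient Tensor and may return a new one
# to replace it; returning None leaves the gradient unchanged.
def double_grad(grad):
    return grad * 2

handle = x.register_hook(double_grad)

y = paddle.sum(x * x)
y.backward()

print(x.grad)    # gradient of y w.r.t. x, scaled by the hook
handle.remove()  # detach the hook once it is no longer needed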