Refactor and simplify hook design & add Tensor.register_hook API #31775
Conversation
Thanks for your contribution!
some comments
@@ -408,9 +412,25 @@ void BasicEngine::Execute() {
    }
  }

  for (auto& pair : tmp_ins) {
How about creating tmp_ins only when it is needed? It seems to make too many temporary variable_wrapper copies here.
done
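For context, a minimal self-contained sketch of the lazy-copy idea settled in this thread. The types here (`VariableWrapper`, `HasHook`, the map layout) are simplified stand-ins for the ones in paddle/fluid/imperative, not the real signatures:

```cpp
#include <map>
#include <memory>
#include <string>
#include <vector>

// Simplified stand-in for paddle::imperative::VariableWrapper.
struct VariableWrapper {
  bool has_hook{false};
  bool HasHook() const { return has_hook; }
};

using NameVarMap =
    std::map<std::string, std::vector<std::shared_ptr<VariableWrapper>>>;

// Only materialize tmp_ins when some input actually carries a hook;
// otherwise the original map is passed through and no per-op copy is made.
const NameVarMap& GetHookedInputs(const NameVarMap& bwd_ins,
                                  NameVarMap& tmp_ins_storage) {
  bool any_hook = false;
  for (const auto& pair : bwd_ins) {
    for (const auto& var : pair.second) {
      if (var && var->HasHook()) {
        any_hook = true;
        break;
      }
    }
    if (any_hook) break;
  }
  if (!any_hook) return bwd_ins;  // fast path: no copy at all
  tmp_ins_storage = bwd_ins;      // copy only when hooks must run
  // ... apply hooks to the copies in tmp_ins_storage here ...
  return tmp_ins_storage;
}
```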
      accumulator->CallBackwardPostHooks();
    }
    // 3. Call backward Hooks for `var_`
    accumulator->CallReduceHooks();
Bad name; or maybe use inheritance to fix it? CallHooks indicates invoking all hooks, but CallReduceHooks makes it confusing to me.
done, CallHooks -> CallGradientHooks
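A sketch of the split the rename settles on; the method names come from this thread, but the surrounding class shape and comments are assumptions, not the exact header:

```cpp
// Post-rename interface sketch for the accumulator's two hook entry points.
class GradientAccumulator {
 public:
  // Runs the VariableWrapperHooks that transform the accumulated grad.
  void CallGradientHooks() { /* apply hooks to the accumulated grad */ }
  // Runs post-accumulation (reduce) callbacks, e.g. for sparse all-reduce
  // in parallel multi-card training; valid only on completed leaf grads.
  void CallReduceHooks() { /* fire one-shot post-accumulation callbacks */ }
};
```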
          platform::errors::InvalidArgument("Leaf Tensor's inner var "
                                            "is not initialized when "
                                            "calling gradient hook."));
  if (var_->HasHook()) {
Seal this, or make it differ from the same code in Execute.
Only the for loop is similar.
  }
}

void GradientAccumulator::CallReduceHooks() {
Add some checks to distinguish it from the normal hooks.
done
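A sketch of the kind of checks the reply refers to. The real code uses PADDLE_ENFORCE-style checks inside GradientAccumulator::CallReduceHooks; the stand-in types, helper names, and messages below are assumptions:

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Stand-ins so the sketch is self-contained.
struct GradVar {
  bool is_leaf_grad{true};
  std::vector<std::function<void()>> reduce_hooks;
};

void CallReduceHooks(const GradVar& var, bool sum_grad_completed) {
  // Reduce hooks differ from normal gradient hooks: they fire only on
  // leaf gradients, and only after accumulation has fully finished.
  assert(var.is_leaf_grad && "only leaf grads may call reduce hooks");
  assert(sum_grad_completed && "reduce hooks run after the sum completes");
  for (const auto& hook : var.reduce_hooks) {
    hook();  // post-accumulation callback, e.g. all-reduce of the grad
  }
}
```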
   * parallel multi-card training.
   */

  void CallHooks();
These two functions are not a parallel structure, so don't give them related names that suggest they are.
done, thx
 */
-class OpBasePreHook {
+class VariableWrapperHook {
How about making an abstract Hook base class to seal the different kinds of hooks?
I have tried that; it is not a good idea.
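For reference, the single hook interface this refactor converges on looks roughly like the following. This is a sketch from the names visible in the diff, with VariableWrapper reduced to a stand-in so it compiles on its own:

```cpp
#include <memory>

struct VariableWrapper {};  // stand-in for the real wrapper type

// One hook interface after the refactor: a hook takes the grad var and
// returns a (possibly new) var, so transforming and effectively in-place
// hooks share a single base instead of a sealed hierarchy per hook kind.
class VariableWrapperHook {
 public:
  virtual ~VariableWrapperHook() = default;
  virtual std::shared_ptr<VariableWrapper> operator()(
      const std::shared_ptr<VariableWrapper>& var) = 0;
};
```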
  int64_t next_hook_id_{0};
  // Hooks used to register hook for grad var, support adding and removing,
  // key is the accumulated int64_t value
  std::map<int64_t, std::shared_ptr<VariableWrapperHook>> hooks_;
Why use a map here?
The hook remove helper needs to hold the hook id so it can remove the hook correctly.
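A self-contained sketch of why a std::map keyed by an accumulating id supports removal correctly. Only next_hook_id_ and hooks_ come from the diff; the registry class and method names here are hypothetical:

```cpp
#include <cstdint>
#include <functional>
#include <map>
#include <memory>

struct VariableWrapper {};  // stand-in

using HookFn = std::function<std::shared_ptr<VariableWrapper>(
    const std::shared_ptr<VariableWrapper>&)>;

class HookRegistry {
 public:
  // The returned id is what the remove helper holds on to; erasing by id
  // stays correct even after other hooks are added or removed, which a
  // plain vector index would not guarantee.
  int64_t AddHook(HookFn hook) {
    hooks_.emplace(next_hook_id_, std::move(hook));
    return next_hook_id_++;
  }
  bool RemoveHook(int64_t hook_id) { return hooks_.erase(hook_id) > 0; }

 private:
  int64_t next_hook_id_{0};
  // Ordered by id, so hooks also run in registration order.
  std::map<int64_t, HookFn> hooks_;
};
```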
LGTM
   * If the gradient has been calculated by a previous graph,
   * it should be added to the previous graph's result.
   * If the leaf gradient calculation is done, the inner_var_
   * should be added to the var_.
   */
  if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) {
The !HasInnerVar() check should be removable now.
I don't think so; every call to AccumulatedGrad still requires an InnerVar.
Right, agreed.
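A toy model of the guard this thread discusses, showing why the !HasInnerVar() arm stays. The members here are stand-ins; the real logic lives in GradientAccumulator::AccumulateGrad:

```cpp
#include <memory>
#include <utility>

struct AccumulatorModel {
  bool is_leaf_grad{true};
  bool sum_completed{false};
  std::shared_ptr<int> inner_var;  // stand-in for inner_var_
  std::shared_ptr<int> var;        // stand-in for var_

  void AccumulateGrad() {
    // Every call still expects an inner var to fold into var_, so the
    // HasInnerVar() check cannot be dropped even for completed leaves.
    if (!is_leaf_grad || !sum_completed || !inner_var) return;
    var = std::move(inner_var);  // fold the accumulated result into var_
  }
};
```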
  for (const auto& hook_pair : var_->GetHooks()) {
    tmp_var = (*hook_pair.second)(tmp_var);
  }
  inner_var_ = tmp_var;
For a leaf node, CallGradientHooks in GradientAccumulator replaces its own internal inner_var_, so it is effectively in-place, right?
Yes, it was in-place to begin with. The main purpose of this change is to unify hook management and invocation under a single base class. If an InplaceHook were used here, the old HookPipeLine machinery would still be needed, and both the data structures and the logic would be fairly complex.
OK
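Putting the loop above in context, the leaf-side hook application reads roughly like this. A sketch only: GetHooks() is modeled as a plain map and the surrounding enforcement checks are omitted:

```cpp
#include <cstdint>
#include <functional>
#include <map>
#include <memory>

struct VariableWrapper {};  // stand-in

using HookFn = std::function<std::shared_ptr<VariableWrapper>(
    const std::shared_ptr<VariableWrapper>&)>;

// Thread the grad var through every registered hook in id order, then
// write the result back. For a leaf, replacing inner_var_ is what makes
// the effect in-place overall, without a separate InplaceHook/HookPipeLine.
void CallGradientHooks(const std::map<int64_t, HookFn>& hooks,
                       std::shared_ptr<VariableWrapper>* inner_var) {
  std::shared_ptr<VariableWrapper> tmp_var = *inner_var;
  for (const auto& hook_pair : hooks) {
    tmp_var = hook_pair.second(tmp_var);  // each hook may return a new var
  }
  *inner_var = tmp_var;  // the leaf grad now reflects all hooks
}
```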
LGTM
LGTM
PR types
New features
PR changes
APIs
Describe
Refactor and simplify hook design & add Tensor.register_hook API
1. Refactor: simplify the Hook class design
2. Add Tensor.register_hook method
3. Doc
related cn doc: PaddlePaddle/docs#3390
Due to a doc extraction issue, the English doc cannot be previewed at the moment.