-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[dy2s] speed up PartialProgram.__call__ #58771
[dy2s] speed up PartialProgram.__call__ #58771
Conversation
88f7b16
to
a041eb0
Compare
e1b2ed0
to
be0f5a9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
autograd_meta->SetStopGradient(var_desc.StopGradient()); | ||
|
||
if (var_type == paddle::framework::proto::VarType::LOD_TENSOR) { | ||
// TODO(jiabin): Maybe support LOD later |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个删了吧。
if (!autograd_meta->GetMutableGradNode()) { | ||
autograd_meta->SetGradNode( | ||
std::make_shared<egr::GradNodeAccumulation>(autograd_meta)); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不需要,删了
if (PyList_Check(var_desc_list)) { | ||
Py_ssize_t len = PyList_Size(var_desc_list); | ||
for (Py_ssize_t i = 0; i < len; i++) { | ||
auto var_desc = PyObjectCast<paddle::framework::VarDesc>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是否可以使用 paddle::framework::VarDesc& ,防止一遍拷贝?
auto var_desc = PyObjectCast<paddle::framework::VarDesc>( | ||
PyList_GetItem(var_desc_list, i)); | ||
auto var_name = var_desc.Name(); | ||
if (out_tensor_map.find(var_name) == out_tensor_map.end()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个Map在PIR下是可以删除的。
In sot, inputs and outputs of partial program only contain tensors, so we can skip some step to speed up | ||
""" | ||
out_vars = self._prepare_outputs() | ||
attrs = self._prepare_attributes() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果SOT将AMP+training考虑了guard,不同的AMP状态会触发转写 / PrtialProgramLayer的重新构造,就可以将这部分加个 cache_property 中。
* move Tensor construction to cpp * mv _remove_no_value to ASTStaticFunction * update
PR types
Others
PR changes
Others
Description
PCard-66972