[Auto Parallel]: Support std::vector<phi::Tensor> input and output for DistTensor. #56602
Conversation
Concat forward and backward are verified.
phi::distributed::DistTensor* dist_tensor =
    static_cast<phi::distributed::DistTensor*>(tensor.impl().get());
intermidiate_tensor_.set_impl(
    std::make_shared<phi::distributed::DistTensor>(
Constructing a new tensor here will trigger a reshard. We still need to discuss this; if it passes as written, we can keep it this way for now.
It has been verified on concat so far and there is no problem. Shall we add a TODO for now?
No problem, I will handle it later.
Minor issues like this can be fixed in a follow-up PR.
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  std::vector<std::shared_ptr<phi::distributed::DistTensor>> out;
  for (auto x : input) {
This can be changed to auto& to avoid copies.
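For context, a minimal standalone sketch (not the PR's code, and using std::shared_ptr<int> as a stand-in for the element type) of the difference the reviewer points out: iterating by value copies each element on every iteration, while a const reference avoids the copy.

```cpp
#include <memory>
#include <vector>

void ByValue(const std::vector<std::shared_ptr<int>>& input) {
  for (auto x : input) {  // copies the shared_ptr each iteration (refcount bump)
    (void)x;
  }
}

void ByReference(const std::vector<std::shared_ptr<int>>& input) {
  for (const auto& x : input) {  // no copy, just a reference to the element
    (void)x;
  }
}
```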
          dense_tensor.meta().is_contiguous()))) {
    out.push_back(
        std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in));
    continue;
In theory this could go into an else branch instead of using continue.
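A small standalone sketch of the restructuring suggested here, using placeholder types and a hypothetical NeedsTransform predicate rather than the PR's actual tensor logic; the two forms are equivalent, the if/else version just avoids the early jump.

```cpp
#include <vector>

// Hypothetical stand-in for the "tensor needs data transform" check.
bool NeedsTransform(int value) { return value % 2 != 0; }

std::vector<int> WithContinue(const std::vector<int>& input) {
  std::vector<int> out;
  for (const auto& x : input) {
    if (!NeedsTransform(x)) {
      out.push_back(x);
      continue;  // skip the transform path below
    }
    out.push_back(x * 10);  // stand-in for the transform path
  }
  return out;
}

std::vector<int> WithElse(const std::vector<int>& input) {
  std::vector<int> out;
  for (const auto& x : input) {
    if (!NeedsTransform(x)) {
      out.push_back(x);
    } else {
      out.push_back(x * 10);  // same transform path, no continue needed
    }
  }
  return out;
}
```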
…r DistTensor. (PaddlePaddle#56602)
* [WIP] Support std::vector<phi::Tensor> input and output for DistTensor. Concat forward and backward are verified.
* Polish code for new dist tensor implementation.
* Fix bug of DistTensor upgrade. Add support functions for std::vector<Tensor> -> std::vector<Tensor>.
* Add support for DistTensor type of std::vector<phi::Tensor> as input or output of operators. Following testcases are passed. 1. concat: std::vector<phi::Tensor> -> phi::Tensor 2. unbind: phi::Tensor -> std::vector<phi::Tensor> 3. broadcast_tensors: std::vector<phi::Tensor> -> std::vector<phi::Tensor>
* Polish code. Remove useless comments.
* Add update_loss_scaling in skip_op_lists.
* Polish code.
PR types: Others
PR changes: Others
Description
Pcard-73145
Support std::vector<phi::Tensor> input and output for DistTensor. Meanwhile, the no_need_buffer type is also supported. Concat, Broadcast_Tensors, and Unbind forward and backward are verified. The following operators (along with their backward ops, if any) still need to be supported later: check_finite_and_unscale, coalesce_tensor, meshgrid, update_loss_scaling, einsum. Forward operators whose output is a std::tuple<...> are not supported yet.
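To make the vector-in/vector-out pattern concrete, here is a minimal, hypothetical sketch, not the code generated by this PR, of how a std::vector of Tensors whose impls are DistTensors can be gathered before dispatching an operator such as concat. The helper name CollectDistInputs and the assumption that every impl already holds a DistTensor are illustrative only; the cast mirrors the diff excerpt above.

```cpp
#include <memory>
#include <vector>

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"

// Hypothetical helper: collect the DistTensor impls of a vector input so the
// auto-parallel branch of an API like concat can dispatch on them.
std::vector<std::shared_ptr<phi::distributed::DistTensor>> CollectDistInputs(
    const std::vector<paddle::Tensor>& inputs) {
  std::vector<std::shared_ptr<phi::distributed::DistTensor>> out;
  out.reserve(inputs.size());
  for (const auto& x : inputs) {
    // Assumes each Tensor holds a DistTensor impl (the auto-parallel case).
    out.push_back(
        std::static_pointer_cast<phi::distributed::DistTensor>(x.impl()));
  }
  return out;
}
```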