s_to_r reshard support unbalanced split #67756

Merged
16 changes: 13 additions & 3 deletions paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc
@@ -377,7 +377,9 @@ pir::Attribute CreateReplicatedDistAttr(pir::Type prim_type,
}
return nullptr;
}
- pir::Type CvtToPirDistType(pir::Type global_type, pir::Attribute dist_attr) {
pir::Type CvtToPirDistType(pir::Type global_type,
pir::Attribute dist_attr,
const std::vector<int64_t>& local_ddim) {
if (!global_type) return nullptr;
auto ctx = pir::IrContext::Instance();
if (auto dense_tensor_type = global_type.dyn_cast<pir::DenseTensorType>()) {
@@ -389,7 +391,14 @@ pir::Type CvtToPirDistType(pir::Type global_type, pir::Attribute dist_attr) {
"Only allowed convert a densor tensor type to dist dense tensor type "
"with non-empty TensorDistAttr"));
}
- return DistDenseTensorType::get(ctx, dense_tensor_type, tensor_dist_attr);
if (!local_ddim.empty()) {
return DistDenseTensorType::get(ctx,
dense_tensor_type,
tensor_dist_attr,
common::make_ddim(local_ddim));
} else {
return DistDenseTensorType::get(ctx, dense_tensor_type, tensor_dist_attr);
}
} else if (auto vec_type = global_type.dyn_cast<pir::VectorType>()) {
auto array_attr = dist_attr.dyn_cast<pir::ArrayAttribute>();
if (!array_attr) {
@@ -406,7 +415,8 @@ pir::Type CvtToPirDistType(pir::Type global_type, pir::Attribute dist_attr) {
"The vector type size must equal to array attribute size."));
std::vector<pir::Type> dist_vec_type;
for (size_t idx = 0; idx < vec_type.size(); ++idx) {
- dist_vec_type.push_back(CvtToPirDistType(vec_type[idx], array_attr[idx]));
dist_vec_type.push_back(
CvtToPirDistType(vec_type[idx], array_attr[idx], local_ddim));
}
return pir::VectorType::get(ctx, dist_vec_type);
} else {
5 changes: 4 additions & 1 deletion paddle/fluid/pir/dialect/distributed/ir/dist_tools.h
@@ -46,7 +46,10 @@ pir::Attribute CvtToPirAttr(const phi::distributed::ArgDistAttr& dist_attr);
pir::Attribute CreateReplicatedDistAttr(pir::Type prim_type,
ProcessMeshAttribute mesh);

- pir::Type CvtToPirDistType(pir::Type global_type, pir::Attribute dist_attr);
pir::Type CvtToPirDistType(
pir::Type global_type,
pir::Attribute dist_attr,
const std::vector<int64_t>& local_ddim = std::vector<int64_t>());

///
/// When the following conditions are met:
13 changes: 12 additions & 1 deletion paddle/fluid/pybind/dist_api.cc
@@ -42,6 +42,7 @@ using paddle::dialect::DistTypeInterface;
using paddle::dialect::OperationDistAttribute;
using paddle::dialect::ProcessMeshAttribute;
using paddle::dialect::TensorDistAttribute;
using pir::ArrayAttribute;

namespace paddle::pybind {

@@ -127,11 +128,21 @@ OperationDistAttribute CreateOperationDistAttribute(
pir::IrContext::Instance(), mesh, operands, results);
}

ArrayAttribute CreateArrayAttribute(
const std::vector<pir::Attribute> &elements) {
return ArrayAttribute::get(pir::IrContext::Instance(), elements);
}

void BindDistUtils(pybind11::module *m) {
m->def("create_tensor_dist_attribute", CreateTensorDistAttribute);
m->def("create_op_dist_attribute", CreateOperationDistAttribute);
m->def("create_array_attribute", CreateArrayAttribute);
m->def("get_sub_meshes", phi::distributed::GetSubMeshes);
m->def("cvt_to_dist_type", &dialect::CvtToPirDistType);
m->def("cvt_to_dist_type",
&dialect::CvtToPirDistType,
py::arg("global_type"),
py::arg("dist_attr"),
py::arg("local_ddim") = std::vector<int64_t>());
}

void BindDistPassAPI(pybind11::module *module) {
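For orientation, the new bindings are meant to be used from Python roughly as follows. This is only a sketch: value, src_dist_attr and dst_dist_attr are placeholders for an existing pir value and its tensor dist attributes, the local shape is a made-up example, and the module path mirrors the calls added later in this PR (paddle.base.libpaddle.pir).

pir_utils = paddle.base.libpaddle.pir

# unchanged behaviour: no local shape given, the dist type is derived as before
dist_type = pir_utils.cvt_to_dist_type(value.type(), src_dist_attr)

# new: pass an explicit local shape so a padded (unbalanced) shard is typed correctly
dist_type = pir_utils.cvt_to_dist_type(
    value.type(), src_dist_attr, local_ddim=[3, 8]  # hypothetical local shape
)

# new helper, used below to build the operand dist attrs of the concat op
operand_attrs = pir_utils.create_array_attribute([src_dist_attr, dst_dist_attr])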
@@ -49,8 +49,9 @@ def infer_allgather_dist_type(self, in_value, split_axis):
# may be shard and it will call this 1-D s_to_r function on each
# axis. In this case, we should recompute the local and global shape.
out_local_shape = list(in_value.shape)
- out_local_shape[split_axis] = (
-     in_value.shape[split_axis] // mesh.shape[split_mesh_dim]
out_local_shape[split_axis] = int(
(in_value.shape[split_axis] + mesh.shape[split_mesh_dim] - 1)
/ mesh.shape[split_mesh_dim]
)
out_global_shape = list(out_local_shape)
out_global_shape[0] *= mesh.shape[split_mesh_dim]
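The switch from floor to ceiling division is what keeps the allgather output type consistent when the split is unbalanced: every rank is typed with the padded shard size. A quick check of the arithmetic with illustrative numbers (not taken from the PR):

dim_size, nranks = 10, 4
local = (dim_size + nranks - 1) // nranks  # 3: padded local size on every rank
gathered = local * nranks                  # 12: trimmed back to 10 once the padding is removed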
@@ -96,10 +97,8 @@ def get_split_axis_with_dims_mapping(dims_mapping):
for k, v in split_axis_map.items():
split_axis = k
break

- num_of_padding = (
-     src_value.shape[split_axis] % src_dist_attr.process_mesh.size
- )
num_of_process = src_dist_attr.process_mesh.size
num_of_padding = src_value.shape[split_axis] % num_of_process
is_balanced_split = num_of_padding == 0

if is_balanced_split:
@@ -113,8 +112,102 @@ def get_split_axis_with_dims_mapping(dims_mapping):
)
return new_value
else:
- # TODO(ywt01) support unbalanced split
- raise NotImplementedError("unbalanced split is not implemented")
# only the last rank in the process mesh holds the remainder shard and needs padding
need_padding = (
paddle.distributed.get_rank()
== src_dist_attr.process_mesh.process_ids[-1]
)

# get padding_num
avg_size_on_split_axis = int(
(src_value.shape[split_axis] + num_of_process - 1)
/ num_of_process
)
padding_num = (
avg_size_on_split_axis * num_of_process
- src_value.shape[split_axis]
)
if need_padding:
# set the correct _local_shape for the remainder shard
local_shape_at_split_axis = src_value.shape[
split_axis
] - avg_size_on_split_axis * (num_of_process - 1)
local_shape = src_value._local_shape
local_shape[split_axis] = local_shape_at_split_axis
tmp_src_type = paddle.base.libpaddle.pir.cvt_to_dist_type(
src_value.type(), src_dist_attr, list(local_shape)
)
src_value.set_type(tmp_src_type)
padding_shape = src_value._local_shape
padding_shape[split_axis] = padding_num
padding_tensor = paddle.full(
padding_shape,
0.0,
src_value.dtype,
)
tmp_src_type1 = paddle.base.libpaddle.pir.cvt_to_dist_type(
padding_tensor.type(), dst_dist_attr
)
padding_tensor.set_type(tmp_src_type1)
padding_tensor.get_defining_op().dist_attr = (
paddle.base.libpaddle.pir.create_op_dist_attribute(
dst_dist_attr.process_mesh, [], [dst_dist_attr]
)
)

concat_value = paddle._C_ops.concat(
[src_value, padding_tensor], split_axis
)
# set concat dist_attr
axis_dist_attr = (
paddle.base.libpaddle.pir.create_tensor_dist_attribute(
src_dist_attr.process_mesh, [-1], {}
)
)
concat_value.get_defining_op().dist_attr = (
paddle.base.libpaddle.pir.create_op_dist_attribute(
src_dist_attr.process_mesh,
[
paddle.base.libpaddle.pir.create_array_attribute(
[src_dist_attr, dst_dist_attr]
),
axis_dist_attr,
],
[src_dist_attr],
)
)
# set concat_value type
concat_global_shape = list(src_value.shape)
concat_global_shape[split_axis] = (
avg_size_on_split_axis * num_of_process
)
concat_type = paddle.pir.create_shaped_type(
src_value.type(), concat_global_shape
)
concat_type = paddle.base.libpaddle.pir.cvt_to_dist_type(
concat_type, src_dist_attr
)
concat_value.set_type(concat_type)

new_value = self.reshard_s_to_r_with_padding(
concat_value,
split_axis,
src_dist_attr,
dst_dist_attr,
dst_type,
padding_num,
)
return new_value
else:
new_value = self.reshard_s_to_r_with_padding(
src_value,
split_axis,
src_dist_attr,
dst_dist_attr,
dst_type,
padding_num,
)
return new_value
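To make the bookkeeping above concrete, here is the same arithmetic for a hypothetical tensor whose split axis has length 10, sharded over a 4-rank mesh (local shards of 3, 3, 3 and 1). Plain Python only, mirroring the formulas in this branch:

shape_on_split_axis = 10
num_of_process = 4

num_of_padding = shape_on_split_axis % num_of_process  # 2 -> unbalanced split
avg_size_on_split_axis = (shape_on_split_axis + num_of_process - 1) // num_of_process  # 3
padding_num = avg_size_on_split_axis * num_of_process - shape_on_split_axis  # 2

# only the last rank pads: its real shard has 10 - 3 * 3 = 1 element on the split
# axis, and grows to the average size 3 after concatenating the zero padding
last_local = shape_on_split_axis - avg_size_on_split_axis * (num_of_process - 1)  # 1
assert last_local + padding_num == avg_size_on_split_axis

# after the padded allgather the global axis is 3 * 4 = 12, and
# reshard_s_to_r_with_padding slices the trailing padding_num elements away again
assert avg_size_on_split_axis * num_of_process - padding_num == shape_on_split_axis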

def reshard_s_to_r_with_padding(
self,
@@ -168,14 +261,29 @@ def reshard_s_to_r_with_padding(
)
pd_splite_op.result(0).set_type(vec_type)

- concat_value = paddle._C_ops.concat(split_values, split_axis)
- # fold builtin.split op and builtin.combine op
- concat_op = concat_value.get_defining_op()
- builtin_combine_op = concat_op.operand_source(0).get_defining_op()
- concat_op.operand(0).set_source(pd_splite_op.result(0))
- builtin_combine_op.erase()
- builtin_split_op.erase()
- return concat_value
if padding_num != 0:
tmp_split_values = paddle._C_ops.split(
split_values[-1],
[
split_values[-1].shape[split_axis] - padding_num,
padding_num,
],
split_axis,
)
split_values[-1] = tmp_split_values[0]
concat_value = paddle._C_ops.concat(split_values, split_axis)
return concat_value
else:
concat_value = paddle._C_ops.concat(split_values, split_axis)
# fold builtin.split op and builtin.combine op
concat_op = concat_value.get_defining_op()
builtin_combine_op = concat_op.operand_source(
0
).get_defining_op()
concat_op.operand(0).set_source(pd_splite_op.result(0))
builtin_combine_op.erase()
builtin_split_op.erase()
return concat_value
return allgather_value
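The trim performed in the padding branch can be reproduced with eager ops for a quick sanity check. A sketch with hypothetical shapes, using paddle.split and paddle.concat instead of the _C_ops calls the pass itself emits:

import paddle

split_axis, padding_num = 0, 2
# four gathered shards of 3 rows each; the last 2 rows of the final shard are padding
split_values = [paddle.ones([3, 4]) for _ in range(4)]

# drop the padding rows from the last shard, keep only the real rows
keep = split_values[-1].shape[split_axis] - padding_num
split_values[-1] = paddle.split(
    split_values[-1], [keep, padding_num], axis=split_axis
)[0]

out = paddle.concat(split_values, axis=split_axis)
print(out.shape)  # [10, 4]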

