add cross mesh r_to_s reshard func for auto parallel. (#63962)
* add cross mesh r_to_s reshard func for auto parallel.

* fix ci
winter-wang authored May 3, 2024
1 parent 8a19b43 commit 2562f58
Showing 14 changed files with 357 additions and 185 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/dist_api.cc
@@ -17,6 +17,7 @@

#include "paddle/fluid/pir/dialect/distributed/ir/dist_api.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
#include "paddle/fluid/pir/dialect/distributed/transforms/dist_to_dense_pass.h"
#include "paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.h"
#include "paddle/fluid/pybind/dist_api.h"
@@ -122,6 +123,7 @@ OperationDistAttribute CreateOperationDistAttribute(
void BindDistUtils(pybind11::module *m) {
m->def("create_tensor_dist_attribute", CreateTensorDistAttribute);
m->def("create_op_dist_attribute", CreateOperationDistAttribute);
m->def("cvt_to_dist_type", &dialect::CvtToPirDistType);
}

void BindDistPassAPI(pybind11::module *module) {
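For orientation, a minimal sketch of what the new `cvt_to_dist_type` binding does when called from Python. The module path used here is an assumption (this diff only shows the pybind registration in `BindDistUtils`); elsewhere in this commit the conversion is reached indirectly through `Value.update_dist_attr`.

    # Hypothetical direct use of the new binding; the module path is an assumption.
    from paddle.base import libpaddle

    def as_dist_type(value, tensor_dist_attr):
        # Mirrors what dialect::CvtToPirDistType does for Value.update_dist_attr:
        # wrap the value's dense tensor type into a dist type carrying tensor_dist_attr.
        return libpaddle.pir.cvt_to_dist_type(value.type(), tensor_dist_attr)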
87 changes: 49 additions & 38 deletions paddle/fluid/pybind/pir.cc
@@ -33,6 +33,7 @@
#include "paddle/fluid/ir_adaptor/translator/utils.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
@@ -815,45 +816,43 @@ py::str Value2String(Value self) {
return print_stream.str();
}

phi::DataType GetValueDtype(Value value) {
if (!value.type()) {
phi::DataType GetTensorDtype(Type type) {
if (!type) {
PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
}
if (value.type().isa<DenseTensorType>()) {
return paddle::dialect::TransToPhiDataType(
value.type().dyn_cast<DenseTensorType>().dtype());
} else if (value.type().isa<SelectedRowsType>()) {
return paddle::dialect::TransToPhiDataType(
value.type().dyn_cast<SelectedRowsType>().dtype());
} else if (value.type().isa<DenseTensorArrayType>()) {
return paddle::dialect::TransToPhiDataType(
value.type().dyn_cast<DenseTensorArrayType>().dtype());
} else if (value.type().isa<DistDenseTensorType>()) {
return paddle::dialect::TransToPhiDataType(
value.type().dyn_cast<DistDenseTensorType>().dtype());
if (auto dense_tensor_type = type.dyn_cast<DenseTensorType>()) {
return dialect::TransToPhiDataType(dense_tensor_type.dtype());
} else if (auto select_rows = type.dyn_cast<SelectedRowsType>()) {
return dialect::TransToPhiDataType(select_rows.dtype());
} else if (auto dense_array = type.dyn_cast<DenseTensorArrayType>()) {
return dialect::TransToPhiDataType(dense_array.dtype());
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only get phi::DataType from DenseTensorType and "
"SelectedRowsType, DistDenseTensorType."));
"SelectedRowsType, DenseTensorArrayType."));
}
}
phi::DataType GetValueDtype(Value value) {
return GetTensorDtype(value.type());
}

const phi::DDim &GetValueDims(Value value) {
if (!value.type()) {
PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
const phi::DDim &GetTensorDims(Type type) {
if (!type) {
PADDLE_THROW(common::errors::InvalidArgument(
"The type used to get dims is nullptr."));
}
if (value.type().isa<DenseTensorType>()) {
return value.type().dyn_cast<DenseTensorType>().dims();
} else if (value.type().isa<SelectedRowsType>()) {
return value.type().dyn_cast<SelectedRowsType>().dims();
} else if (value.type().isa<DistDenseTensorType>()) {
return value.type().dyn_cast<DistDenseTensorType>().global_ddim();
if (auto dense_type = type.dyn_cast<DenseTensorType>()) {
return dense_type.dims();
} else if (auto select_rows_type = type.dyn_cast<SelectedRowsType>()) {
return select_rows_type.dims();
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only get shape for dense and distdense"
"tensor."));
PADDLE_THROW(common::errors::InvalidArgument(
"Currently, we can only get shape for dense and selsect rows type."));
}
}
const phi::DDim &GetValueDims(Value value) {
return GetTensorDims(value.type());
}

pir::Value apply(Value self, py::object func) {
py::gil_scoped_acquire gil;
@@ -1100,13 +1099,10 @@ void BindValue(py::module *m) {
}
return self.type().dyn_cast<DistTypeInterface>().tensor_dist_attr();
})
// The function will calculate the new local shape based on the global
// shape and the dist_attr argument.
.def("update_dist_attr", [](Value &self, TensorDistAttribute dist_attr) {
if (auto dist_type = self.type().dyn_cast<DistTypeInterface>()) {
self.set_type(dist_type.CopyWithNewDistAttr(dist_attr));
} else {
PADDLE_THROW(common::errors::InvalidArgument(
"update_dist_attr is only for dist type tensor."));
}
self.set_type(dialect::CvtToPirDistType(self.type(), dist_attr));
});
}

@@ -1137,11 +1133,26 @@ bool GetValueBoolAttr(Value value, const std::string &attr_name) {

void BindType(py::module *m) {
py::class_<Type> ir_type(*m, "Type");
ir_type.def("__eq__", &Type::operator==).def("__str__", [](Type &self) {
std::ostringstream print_stream;
print_stream << self;
return print_stream.str();
});
ir_type.def("__eq__", &Type::operator==)
.def_property(
"shape",
[](Type self) { return phi::vectorize(GetTensorDims(self)); },
[](Type self, const std::vector<int> &shape) {
PADDLE_THROW(phi::errors::InvalidArgument(
"can't set shape when building static graph"));
})
.def_property(
"dtype",
[](Type self) { return GetTensorDtype(self); },
[](Type self, phi::DataType dtype) {
PADDLE_THROW(phi::errors::InvalidArgument(
"can't set dtype when building static graph"));
})
.def("__str__", [](Type &self) {
std::ostringstream print_stream;
print_stream << self;
return print_stream.str();
});

m->def("create_shaped_type",
[](Type &type, const std::vector<int> &shape) -> Type {
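The new `shape` and `dtype` properties on `Type` let Python-side passes inspect a result type directly, for example the `dst_type` handed to a reshard function later in this commit. A small illustrative sketch; only the two read-only properties (and `Value.type()`, which `pir_pass.py` already uses) are taken from this diff:

    # Illustrative only: inspect a pir Type without going through a Value.
    def describe_type(value):
        t = value.type()          # pir Type, as used in pir_pass.py
        return t.shape, t.dtype   # backed by GetTensorDims / GetTensorDtype

    # Both properties are read-only; assigning to them raises InvalidArgument,
    # e.g. "can't set shape when building static graph".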
4 changes: 2 additions & 2 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -672,12 +672,12 @@ def _parallel_pir(self, mode):
# TODO(JZ-LIANG) Step 3.1: Partition Pass
# insert reshard op if an operand tensor's placements are different from what the consumer op needs.
# Partition the computation graph into different pipeline stages if needed.
dist_program = apply_partition_pass(dist_program)
apply_partition_pass(dist_program)

# TODO(hitywt) Step 3.2: Reshard Pass
# resolve the reshard ops into concrete collective operations.
# collect the communicators created during resolution.
dist_program = apply_reshard_pass(dist_program)
apply_reshard_pass(dist_program)

# Part 4: Optimization Pass
# NOTE Only optimization passes that relate to parallelism (need dist attr) should be placed here, and all of them should be optional.
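With this change both passes follow an in-place contract: they mutate `dist_program` and return nothing, so the previous reassignments are dropped. A minimal sketch of the resulting call pattern, with the surrounding `_parallel_pir` steps elided:

    # Sketch of the in-place pass contract after this commit.
    dist_program = ...  # pir program produced by the earlier steps in _parallel_pir (elided)

    apply_partition_pass(dist_program)   # rewrites operands/results in place
    apply_reshard_pass(dist_program)     # lowers dist_op.reshard ops in place
    # no reassignment: both passes now return None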
167 changes: 79 additions & 88 deletions python/paddle/distributed/auto_parallel/static/pir_pass.py
Expand Up @@ -61,99 +61,90 @@ def reshard_combine_value(op, operand, attr):


def apply_partition_pass(program):
with paddle.static.program_guard(program):
for op in program.global_block().ops:
if op.name() in partition_skip_op_list:
continue
assert len(op.operands()) == len(
op.dist_attr.operands()
), f"The number of operands and the number of op_dist_attr's operands are not equal in op: {op}"

for operand, attr in zip(op.operands(), op.dist_attr.operands()):
prev_var = operand.source()
if prev_var.is_combine():
operand.set_source(reshard_combine_value(op, operand, attr))
else:
operand.set_source(reshard_single_value(op, operand, attr))
prev_op = prev_var.get_defining_op()
if (
prev_op
and prev_op.num_results() == 1
and prev_var.use_empty()
):
prev_op.erase()

for var, attr in zip(op.results(), op.dist_attr.results()):
if (
var.initialized()
and var.is_dist()
and var.dist_attr() != attr
):
paddle.pir.set_insertion_point_after(op)
old_dist_attr = var.dist_attr()
var.update_dist_attr(attr.as_tensor_dist_attr())
# insert reshard
reshard_var = paddle._C_ops.reshard_v2(var, old_dist_attr)
var.replace_all_uses_with(reshard_var)
reshard_var.get_defining_op().operand(0).set_source(var)

# prune ops and values that do not belong to cur rank
cur_rank = paddle.distributed.get_rank()
for op in program.global_block().ops[::-1]:
if cur_rank not in op.dist_attr.process_mesh.process_ids:
program.global_block().remove_op(op)
for op in program.global_block().ops:
if op.name() in partition_skip_op_list:
continue
assert len(op.operands()) == len(
op.dist_attr.operands()
), f"The number of operands and the number of op_dist_attr's operands are not equal in op: {op}"

for operand, attr in zip(op.operands(), op.dist_attr.operands()):
prev_var = operand.source()
if prev_var.is_combine():
operand.set_source(reshard_combine_value(op, operand, attr))
else:
# set the operand to null when it does not belong to cur rank
if (
op.name() == 'dist_op.reshard'
and cur_rank
not in op.operand(0)
.source()
.dist_attr()
.process_mesh.process_ids
):
op.operand(0).set_source(None)

# merge pd.data ops for learning rate
lr_ops = []
for op in program.global_block().ops[::-1]:
operand.set_source(reshard_single_value(op, operand, attr))
prev_op = prev_var.get_defining_op()
if prev_op and prev_op.num_results() == 1 and prev_var.use_empty():
prev_op.erase()

for var, attr in zip(op.results(), op.dist_attr.results()):
if var.initialized() and var.is_dist() and var.dist_attr() != attr:
paddle.pir.set_insertion_point_after(op)
old_dist_attr = var.dist_attr()
var.update_dist_attr(attr.as_tensor_dist_attr())
# insert reshard
reshard_var = paddle._C_ops.reshard_v2(var, old_dist_attr)
var.replace_all_uses_with(reshard_var)
reshard_var.get_defining_op().operand(0).set_source(var)

# prune ops and values that do not belong to cur rank
cur_rank = paddle.distributed.get_rank()
for op in program.global_block().ops[::-1]:
if cur_rank not in op.dist_attr.process_mesh.process_ids:
op.erase()
else:
# set the operand to null when it does not belong to cur rank
if (
op.name() == 'pd_op.data'
and "learning_rate" in op.attrs()["name"]
op.name() == 'dist_op.reshard'
and cur_rank
not in op.operand(0)
.source()
.dist_attr()
.process_mesh.process_ids
):
lr_ops.append(op)

if len(lr_ops) > 1:
lr_value = lr_ops[0].result(0)
for op in lr_ops[1:]:
lr = op.result(0)
lr.replace_all_uses_with(lr_value)
program.global_block().remove_op(op)
return program
op.operand(0).set_source(None)

# merge pd.data ops for learning rate
lr_ops = []
for op in program.global_block().ops[::-1]:
if op.name() == 'pd_op.data' and "learning_rate" in op.attrs()["name"]:
lr_ops.append(op)

if len(lr_ops) > 1:
lr_value = lr_ops[0].result(0)
for op in lr_ops[1:]:
lr = op.result(0)
lr.replace_all_uses_with(lr_value)
op.erase()


def apply_reshard_pass(program):
new_program = program.clone()
with paddle.base.program_guard(new_program):
for op in new_program.global_block().ops:
if op.name() == 'dist_op.reshard':
var = op.operand_source(0)
op_dist_attr = op.dist_attr
src_dist_attr = op_dist_attr.operand(0).as_tensor_dist_attr()
dst_dist_attr = op_dist_attr.result(0).as_tensor_dist_attr()
assert (
not var.initialized() or var.dist_attr() == src_dist_attr
), f"The dist_attr of reshard op's input and operand should be equal, but got {var.dist_attr()} and {src_dist_attr}"

reshard_func = choose_reshard_func(src_dist_attr, dst_dist_attr)
assert (
reshard_func is not None
), f'There is no reshard function that matches src_dist_attr: {src_dist_attr} and dst_dist_attr: {dst_dist_attr}'
reshard_func.reshard(
new_program, op, src_dist_attr, dst_dist_attr
)

return new_program
for op in program.global_block().ops:
if op.name() == 'dist_op.reshard':
var = op.operand_source(0)
op_dist_attr = op.dist_attr
src_dist_attr = op_dist_attr.operand(0).as_tensor_dist_attr()
dst_dist_attr = op_dist_attr.result(0).as_tensor_dist_attr()
assert (
not var.initialized() or var.dist_attr() == src_dist_attr
), f"The dist_attr of reshard op's input and operand should be equal, but got {var.dist_attr()} and {src_dist_attr}"

reshard_func = choose_reshard_func(src_dist_attr, dst_dist_attr)
assert (
reshard_func is not None
), f'There is no reshard function that matches src_dist_attr: {src_dist_attr} and dst_dist_attr: {dst_dist_attr}'
paddle.pir.set_insertion_point_after(op)
out_value = reshard_func.reshard(
src_dist_attr,
dst_dist_attr,
op.operand_source(0),
op.result(0).type(),
)
if out_value is not None:
op.result(0).replace_all_uses_with(out_value)
if op.result(0).use_empty():
op.erase()


# In sequence_parallel, we need to transpose hidden_states
@@ -183,5 +174,5 @@ def eliminate_transpose_by_reshape(program):
transpose_var = op.result(0)
reshape_var = paddle._C_ops.reshape(var, transpose_var.shape)
transpose_var.replace_all_uses_with(reshape_var)
program.global_block().remove_op(op)
op.erase()
return program
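The `dist_op.reshard` ops this pass lowers encode placement changes such as the cross-mesh replicated-to-shard (r_to_s) case named in the commit title. For intuition, a dynamic-mode sketch of the same semantics using the public auto-parallel API; the meshes and shapes are made up, and this commit itself only touches the static pir path:

    # Illustrative semantics of a cross-mesh r_to_s reshard (dynamic mode).
    import paddle
    import paddle.distributed as dist

    src_mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    dst_mesh = dist.ProcessMesh([2, 3], dim_names=["x"])

    w = paddle.randn([8, 4])
    w_rep = dist.shard_tensor(w, src_mesh, [dist.Replicate()])   # replicated on src_mesh
    w_shard = dist.reshard(w_rep, dst_mesh, [dist.Shard(0)])     # sharded along dim 0 on dst_mesh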
@@ -20,7 +20,7 @@ class ReshardFunction:
def is_suitable(self, dist_tensor, dist_attr):
raise NotImplementedError

def reshard(self, program, op, src_tensor, dst_dist_attr):
def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
raise NotImplementedError


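Given the new `reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type)` signature and the way `apply_reshard_pass` drives it, a concrete reshard function builds its replacement ops at the insertion point the pass has already placed after the `dist_op.reshard` op and returns the value that should replace the op's result. Below is a hypothetical minimal subclass, not the actual cross-mesh r_to_s implementation added by this commit (that file is not shown above); the only calls taken from this diff are `paddle._C_ops.reshape`, `Value.shape`, and `Value.update_dist_attr`:

    # Hypothetical no-op reshard function under the new interface.
    # ReshardFunction is the base class shown in the hunk above.
    import paddle

    class SameStatusReshardSketch(ReshardFunction):
        def is_suitable(self, dist_tensor, dist_attr):
            return True  # sketch only; real functions match src/dst placements here

        def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
            # apply_reshard_pass has set the insertion point right after the
            # dist_op.reshard op, so ops built here land in the correct spot.
            out = paddle._C_ops.reshape(src_value, src_value.shape)  # stand-in for real comm/slice ops
            out.update_dist_attr(dst_dist_attr)  # give the result the target placement
            return out  # the pass rewires uses of op.result(0) to this value and erases the op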