From dd0aafdd51dee1f191939b6fa84fd152bbf1f55f Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Sun, 9 Jan 2022 11:57:03 +0000 Subject: [PATCH 01/11] Add the backward support for QR --- paddle/fluid/operators/qr_op.h | 123 +++++++++++++++- paddle/fluid/operators/svd_helper.h | 135 ++++++++++++++++++ .../fluid/tests/unittests/CMakeLists.txt | 1 + .../fluid/tests/unittests/test_qr_op.py | 101 +++++++++++++ 4 files changed, 358 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/qr_op.h b/paddle/fluid/operators/qr_op.h index 73ba52f590c0d..65dfb4261e96e 100644 --- a/paddle/fluid/operators/qr_op.h +++ b/paddle/fluid/operators/qr_op.h @@ -19,6 +19,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/operators/svd_helper.h" #include "paddle/fluid/platform/for_range.h" namespace paddle { @@ -79,9 +80,11 @@ class QrCPUKernel : public framework::OpKernel { q_data = q.mutable_data>( context.GetPlace(), size_t(batch_size * m * k * sizeof(math::Real))); + memset(q_data, 0, size_t(batch_size * m * k * sizeof(math::Real))); } auto* r_data = r.mutable_data>( context.GetPlace(), size_t(batch_size * k * n * sizeof(math::Real))); + memset(r_data, 0, size_t(batch_size * k * n * sizeof(math::Real))); // Implement QR by calling Eigen for (int i = 0; i < batch_size; ++i) { @@ -126,8 +129,124 @@ template class QrGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { - PADDLE_THROW(platform::errors::InvalidArgument( - "QR doesn't have the backward kernel now and will be supported soon.")); + const framework::Tensor& Q = *ctx.Input("Q"); + const framework::Tensor& R = *ctx.Input("R"); + // Use a different name A instead of X + const framework::Tensor& A = *ctx.Input("X"); + const framework::Tensor& dQ = + *ctx.Input(framework::GradVarName("Q")); + const framework::Tensor& dR = + *ctx.Input(framework::GradVarName("R")); + // Use a different name dA instead of dX + framework::Tensor& dA = + *ctx.Output(framework::GradVarName("X")); + dA.mutable_data>(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + math::SetConstant()(dev_ctx, &dA, T(0)); + + auto dito = math::DeviceIndependenceTensorOperations(ctx); + + std::string mode = ctx.Attr("mode"); + bool compute_q, reduced; + std::tie(compute_q, reduced) = _parse_qr_mode(mode); + if (!compute_q) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The derivative of qr is not implemented when mode='r'.")); + } + + auto a_dims = A.dims(); + int a_rank = a_dims.size(); + int m = a_dims[a_rank - 2]; + int n = a_dims[a_rank - 1]; + + if ((m > n) && (!reduced)) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The derivative of qr is not implemented when mode='complete' and " + "nrows > ncols.")); + } + + // m >= n case + auto m_gt_n_case = []( + const framework::ExecutionContext& ctx, + math::DeviceIndependenceTensorOperations& dito, + const Tensor& dQ, const Tensor& dR, const Tensor& A, const Tensor& Q, + const Tensor& R) -> framework::Tensor { + // Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable + // Programming Tensor Networks. + // https://arxiv.org/abs/1903.09650 Section 3. 
QR factorization + + // dR^H + framework::Tensor R_term; + if (ctx.HasInput(framework::GradVarName("R"))) { + R_term = dito.Matmul(R, dito.Transpose(dR)); + } else { + R_term = dito.Fill(framework::vectorize(R.dims()), 0); + } + + // dQ^H * Q + framework::Tensor Q_term; + if (ctx.HasInput(framework::GradVarName("Q"))) { + Q_term = dito.Matmul(dito.Transpose(dQ), Q); + } else { + Q_term = dito.Fill(framework::vectorize(R.dims()), 0); + } + + framework::Tensor M_tmp1 = dito.Sub(R_term, Q_term); + + // Compute M = (tril(M) + tril(M).mH()) * 0.5 Identity + framework::Tensor M_tril_0 = dito.TrilTriu(M_tmp1, 0, true); + framework::Tensor M_tril_1 = dito.TrilTriu(M_tmp1, -1, true); + framework::Tensor M = dito.Add(M_tril_0, dito.Transpose(M_tril_1)); + + framework::Tensor rhs_term; + if (ctx.HasInput(framework::GradVarName("Q"))) { + rhs_term = dito.Add(dQ, dito.Matmul(Q, M)); + } else { + rhs_term = dito.Matmul(Q, M); + } + + // dA * R^H = rhs_term + auto dA = + dito.TriangularSolve(dito.Transpose(dito.Conj(dito.Transpose(R))), + dito.Transpose(rhs_term), + /*upper=*/true, + /*transpose=*/false, + /*unitriangular=*/false); + + return dito.Transpose(dA); + }; + + if (m >= n) { + auto dA_tmp = m_gt_n_case(ctx, dito, dQ, dR, A, Q, R); + framework::TensorCopy(dA_tmp, dA.place(), &dA); + } else { + // If m < n for input matrices A, we partition A = [X|Y] and R = [U|V] + // Calculate dX and dY individually and concatenate them to get dA + dA.mutable_data>(ctx.GetPlace()); + + auto Y = dito.Slice(A, {-1}, {m}, {n}); + auto U = dito.Slice(R, {-1}, {0}, {m}); + framework::Tensor dY, dX, dV, dR_tmp, dQ_prime; + + if (ctx.HasInput(framework::GradVarName("R"))) { + dV = dito.Slice(dR, {-1}, {m}, {n}); + dR_tmp = dito.Slice(dR, {-1}, {0}, {m}); + // Y * dV^H + dQ_prime = dito.Matmul(Y, dito.Transpose(dV)); + } else { + dV = dito.Fill(framework::vectorize(Y.dims()), 0); + dQ_prime = dito.Fill(framework::vectorize(Q.dims()), 0); + } + + if (ctx.HasInput(framework::GradVarName("Q"))) { + dQ_prime = dito.Add(dQ_prime, dQ); + } + dX = m_gt_n_case(ctx, dito, dQ_prime, dR_tmp, A, Q, U); + dY = dito.Matmul(Q, dV); + // Concatenate dX and dY to get dA. + auto dA_tmp = dito.ConcatTwoTensors(dX, dY, -1); + framework::TensorCopy(dA_tmp, dA.place(), &dA); + } } }; diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index 6b2584682277e..8d17ddec6fbb4 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -146,6 +146,93 @@ static std::vector GetBroadcastShape(InTensors ins) { return broadcast_shape; } +static inline framework::DDim ComputeAndCheckShapeForConcatOp( + const bool is_runtime, const std::vector& inputs_dims, + const size_t axis) { + const size_t n = inputs_dims.size(); + auto out_dims = inputs_dims[0]; + size_t in_zero_dims_size = out_dims.size(); + for (size_t i = 1; i < n; i++) { + PADDLE_ENFORCE_EQ(inputs_dims[i].size(), out_dims.size(), + platform::errors::InvalidArgument( + "The shape of input[0] and input[%d] " + "is expected to be equal." 
+ "But received input[0]'s shape = " + "[%s], input[%d]'s shape = [%s].", + i, inputs_dims[0], i, inputs_dims[i])); + for (size_t j = 0; j < in_zero_dims_size; j++) { + if (j == axis) { + if (is_runtime) { + out_dims[axis] += inputs_dims[i][j]; + } else { + if (inputs_dims[i][j] == -1 || out_dims[j] == -1) { + out_dims[axis] = -1; + } else { + out_dims[axis] += inputs_dims[i][j]; + } + } + } else { + bool check_shape = + is_runtime || (inputs_dims[0][j] > 0 && inputs_dims[i][j] > 0); + if (check_shape) { + // check all shape in run time + PADDLE_ENFORCE_EQ(inputs_dims[0][j], inputs_dims[i][j], + platform::errors::InvalidArgument( + "The %d-th dimension of input[0] and input[%d] " + "is expected to be equal." + "But received input[0]'s shape = " + "[%s], input[%d]'s shape = [%s].", + j, i, inputs_dims[0], i, inputs_dims[i])); + } + if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0) { + out_dims[j] = inputs_dims[i][j]; + } + } + } + } + return out_dims; +} + +static inline int64_t ComputeAxisForConcatOp(int64_t axis, int64_t rank) { + PADDLE_ENFORCE_EQ( + axis >= -rank && axis < rank, true, + platform::errors::InvalidArgument( + "The axis is expected to be in range of [%d, %d), but got %d", -rank, + rank, axis)); + if (axis < 0) { + axis = axis + rank; + } + return axis > 0 ? axis : 0; +} + +// Prepared for the broadcast operation +static std::vector get_broadcast_batch_portion( + std::vector x, std::vector y) { + size_t size_x = x.size(); + size_t size_y = y.size(); + size_t size = std::max(size_x, size_y); + std::vector batchPortion(size); + + ptrdiff_t i = (ptrdiff_t)size - 1; + for (; i >= 0; --i) { + ptrdiff_t offset = size - i - 1; + ptrdiff_t dim_x = size_x - offset - 1; + ptrdiff_t dim_y = size_y - offset - 1; + int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1; + int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1; + + PADDLE_ENFORCE_EQ( + (x_size == y_size || x_size == 1 || y_size == 1), true, + platform::errors::PreconditionNotMet( + "The size of tensor x (%d) must match the size of tensor y " + "(%d) at non-singleton dimension %d.", + x_size, y_size, i)); + + batchPortion[i] = x_size != 1 ? 
x_size : y_size; + } + return batchPortion; +} + #define DITO_TRANSPOSE_RANK_CASE(N) \ case N: { \ math::Transpose trans; \ @@ -515,6 +602,54 @@ struct DeviceIndependenceTensorOperations { return CreateOpRunAndReturnTensor("tril_triu", inputs, attrs, out_shape); } + framework::Tensor TriangularSolve(const framework::Tensor& x, + const framework::Tensor& y, bool upper, + bool transpose, bool unitriangular) { + framework::AttributeMap attrs; + attrs["upper"] = upper; + attrs["transpose"] = transpose; + attrs["unitriangular"] = unitriangular; + NameInTensorMap inputs({{"X", {&x}}, {"Y", {&y}}}); + auto x_dims = x.dims(); + auto y_dims = y.dims(); + auto y_dims_n = y_dims.size(); + std::vector x_dims_vec = + paddle::framework::vectorize(x_dims); + std::vector y_dims_vec = + paddle::framework::vectorize(y_dims); + std::vector x_dims_vec_cut(x_dims_vec.begin(), + x_dims_vec.end() - 2); + std::vector y_dims_vec_cut(y_dims_vec.begin(), + y_dims_vec.end() - 2); + std::vector expand_batch_portion = + get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut); + std::vector y_broadcast_dims({expand_batch_portion}); + y_broadcast_dims.insert(y_broadcast_dims.end(), {y_dims_vec[y_dims_n - 2], + y_dims_vec[y_dims_n - 1]}); + std::vector out_shape(y_broadcast_dims.begin(), + y_broadcast_dims.end()); + return CreateOpRunAndReturnTensor("triangular_solve", inputs, attrs, + out_shape); + } + + framework::Tensor ConcatTwoTensors(const framework::Tensor& x, + const framework::Tensor& y, int axis) { + framework::AttributeMap attrs; + attrs["axis"] = axis; + std::vector inputs_dims({x.dims(), y.dims()}); + NameInTensorMap inputs({{"X", {&x, &y}}}); + size_t axis_ = + ComputeAxisForConcatOp(static_cast(axis), + static_cast(inputs_dims[0].size())); + framework::DDim out_dims = + ComputeAndCheckShapeForConcatOp(true, inputs_dims, axis_); + if (out_dims[axis_] < 0) { + out_dims[axis_] = -1; + } + std::vector out_shape = framework::vectorize(out_dims); + return CreateOpRunAndReturnTensor("concat", inputs, attrs, out_shape); + } + Tensor Conj(const Tensor& x) { Tensor out; auto* out_data = out.mutable_data(x.dims(), context.GetPlace()); diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index e987255f47b65..1e2583b2008ff 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -975,6 +975,7 @@ set_tests_properties(test_lstm_cudnn_op PROPERTIES TIMEOUT 120) set_tests_properties(test_stack_op PROPERTIES TIMEOUT 120) set_tests_properties(test_bilinear_interp_v2_op PROPERTIES TIMEOUT 120) set_tests_properties(test_svd_op PROPERTIES TIMEOUT 80) +set_tests_properties(test_qr_op PROPERTIES TIMEOUT 60) set_tests_properties(test_deformable_psroi_pooling PROPERTIES TIMEOUT 120) set_tests_properties(test_trilinear_interp_v2_op PROPERTIES TIMEOUT 120) set_tests_properties(test_imperative_static_runner_mnist PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/test_qr_op.py b/python/paddle/fluid/tests/unittests/test_qr_op.py index ea2aaf3f00d5b..b0d072c6ad342 100644 --- a/python/paddle/fluid/tests/unittests/test_qr_op.py +++ b/python/paddle/fluid/tests/unittests/test_qr_op.py @@ -21,6 +21,107 @@ import paddle.fluid as fluid import paddle.fluid.layers as layers import paddle.fluid.core as core +from op_test import OpTest + + +class TestQrOp(OpTest): + def setUp(self): + paddle.enable_static() + np.random.seed(4) + # self._cpu_only = True + self.op_type = "qr" + a, q, r = 
self.get_input_and_output() + self.inputs = {"X": a} + self.attrs = {"mode": self.get_mode()} + self.outputs = {"Q": q, "R": r} + + def get_dtype(self): + return "float64" + + def get_mode(self): + return "reduced" + + def get_shape(self): + return (11, 11) + + def get_input_and_output(self): + dtype = self.get_dtype() + shape = self.get_shape() + mode = self.get_mode() + assert mode != "r", "Cannot be backward in r mode." + a = np.random.rand(*shape).astype(dtype) + # a = np.array([[1, 2, 3, 2], [4, 5, 6, 3], [7, 8, 7, 4]]).astype(dtype) + # a = np.array([[1, 2], [3, 4], [5, 6]]).astype(dtype) + # a = np.array([[[1, 2], [3, 4], [5, 6]], [[1, 2], [3, 4], [5, 6]]]).astype(dtype) + m = a.shape[-2] + n = a.shape[-1] + min_mn = min(m, n) + if mode == "reduced": + k = min_mn + else: + k = m + q_shape = list(a.shape[:-2]) + q_shape.extend([m, k]) + r_shape = list(a.shape[:-2]) + r_shape.extend([k, n]) + q = np.zeros(q_shape).astype(dtype) + r = np.zeros(r_shape).astype(dtype) + batch_size = a.size // (a.shape[-1] * a.shape[-2]) + for i in range(batch_size): + coord = np.unravel_index(i, a.shape[:-2]) + tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode) + q[coord] = tmp_q + r[coord] = tmp_r + return a, q, r + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + # dQ = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]).astype(self.dtype) + # dR = np.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]).astype(self.dtype) + # dQ = np.array([[1, 1], [1, 1], [1, 1]]).astype(self.dtype) + # dR = np.array([[1, 1], [1, 1]]).astype(self.dtype) + # dQ = np.array([[[1, 1], [1, 1], [1, 1]], [[1, 1], [1, 1], [1, 1]]]).astype(self.dtype) + # dR = np.array([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]).astype(self.dtype) + # self.check_grad(['X'], ['Q', 'R'], user_defined_grad_outputs=[dQ, dR]) + self.check_grad(['X'], ['Q', 'R']) + + +class TestQrOpCase1(TestQrOp): + def get_shape(self): + return (10, 12) + + +class TestQrOpCase2(TestQrOp): + def get_shape(self): + return (16, 15) + + +class TestQrOpCase3(TestQrOp): + def get_shape(self): + return (2, 12, 16) + + +class TestQrOpCase4(TestQrOp): + def get_shape(self): + return (3, 16, 15) + + +class TestQrOpCase5(TestQrOp): + def get_mode(self): + return "complete" + + def get_shape(self): + return (10, 12) + + +class TestQrOpCase6(TestQrOp): + def get_mode(self): + return "complete" + + def get_shape(self): + return (2, 10, 12) class TestQrAPI(unittest.TestCase): From 1c590d585cf29f5b7ef6324c4b693a9691c3e39d Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Mon, 10 Jan 2022 04:06:25 +0000 Subject: [PATCH 02/11] Remove unnecessary comments --- python/paddle/fluid/tests/unittests/test_qr_op.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_qr_op.py b/python/paddle/fluid/tests/unittests/test_qr_op.py index b0d072c6ad342..4be46837a67ae 100644 --- a/python/paddle/fluid/tests/unittests/test_qr_op.py +++ b/python/paddle/fluid/tests/unittests/test_qr_op.py @@ -28,7 +28,6 @@ class TestQrOp(OpTest): def setUp(self): paddle.enable_static() np.random.seed(4) - # self._cpu_only = True self.op_type = "qr" a, q, r = self.get_input_and_output() self.inputs = {"X": a} @@ -50,9 +49,6 @@ def get_input_and_output(self): mode = self.get_mode() assert mode != "r", "Cannot be backward in r mode." 
a = np.random.rand(*shape).astype(dtype) - # a = np.array([[1, 2, 3, 2], [4, 5, 6, 3], [7, 8, 7, 4]]).astype(dtype) - # a = np.array([[1, 2], [3, 4], [5, 6]]).astype(dtype) - # a = np.array([[[1, 2], [3, 4], [5, 6]], [[1, 2], [3, 4], [5, 6]]]).astype(dtype) m = a.shape[-2] n = a.shape[-1] min_mn = min(m, n) @@ -78,13 +74,6 @@ def test_check_output(self): self.check_output() def test_check_grad_normal(self): - # dQ = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]).astype(self.dtype) - # dR = np.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]).astype(self.dtype) - # dQ = np.array([[1, 1], [1, 1], [1, 1]]).astype(self.dtype) - # dR = np.array([[1, 1], [1, 1]]).astype(self.dtype) - # dQ = np.array([[[1, 1], [1, 1], [1, 1]], [[1, 1], [1, 1], [1, 1]]]).astype(self.dtype) - # dR = np.array([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]).astype(self.dtype) - # self.check_grad(['X'], ['Q', 'R'], user_defined_grad_outputs=[dQ, dR]) self.check_grad(['X'], ['Q', 'R']) @@ -270,5 +259,4 @@ def run_qr_static(shape, mode, dtype): if __name__ == "__main__": - paddle.enable_static() unittest.main() From 46c14667250830e5bb19a415f555420f5344b71f Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Tue, 18 Jan 2022 04:18:15 +0000 Subject: [PATCH 03/11] [Auto Parallel] Improve the dist op interface and compatible computation --- .../distributed/auto_parallel/completion.py | 72 +-- .../distributed/auto_parallel/dist_context.py | 92 ++- .../distributed/auto_parallel/dist_op.py | 5 +- .../auto_parallel/operators/__init__.py | 2 + .../auto_parallel/operators/common.py | 188 ++++-- .../dist_check_finite_and_unscale.py | 13 +- .../auto_parallel/operators/dist_default.py | 151 ++++- .../auto_parallel/operators/dist_embedding.py | 35 +- .../auto_parallel/operators/dist_matmul.py | 561 ++++-------------- .../auto_parallel/operators/dist_reshape.py | 51 +- .../auto_parallel/operators/dist_softmax.py | 37 +- .../auto_parallel/operators/dist_transpose.py | 21 +- .../operators/dist_update_loss_scaling.py | 14 +- .../distributed/auto_parallel/partitioner.py | 34 +- .../distributed/auto_parallel/planner.py | 8 +- .../distributed/auto_parallel/reshard.py | 1 + .../paddle/distributed/auto_parallel/rules.py | 94 +++ .../tests/unittests/test_auto_parallel_api.py | 4 +- .../test_auto_parallel_completion.py | 31 +- .../unittests/test_auto_parallel_mapper.py | 18 +- .../unittests/test_auto_parallel_searcher.py | 4 +- .../test_auto_search_dist_matmul_op.py | 14 +- .../unittests/test_auto_search_dist_op.py | 87 +-- 23 files changed, 764 insertions(+), 773 deletions(-) create mode 100644 python/paddle/distributed/auto_parallel/rules.py diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index b03858119296e..0188923ffd0d3 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -353,30 +353,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): compatible_dims_mapping) changed = True # Find the most compatible implemenetations from the distributed operator - op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl( - op_desc.type(), dist_op, fwd=True) - if op_dist_impl is not None: - dim_changed = op_dist_impl.update_dims_mapping(dist_op) - if dim_changed: - changed = True - # This statement will be replaced by a good way - if op_dist_impl.is_compatible(dist_op): - op_dist_attr.impl_type = op_desc.type() - op_dist_attr.impl_idx = op_dist_impl_idx - elif 
is_elementwise_like_op(op_desc.type()): - dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl( - dist_context, op_node) - if dim_changed: - changed = True - op_dist_attr.impl_type = "element-wise" - op_dist_attr.impl_idx = -1 - else: - dim_changed = update_op_dims_mapping_by_default_dist_impl( - dist_context, op_node) - if dim_changed: - changed = True - op_dist_attr.impl_type = "default" - op_dist_attr.impl_idx = -2 + op_dist_impl = find_best_compatible_distributed_operator_impl( + dist_op, fwd=True) + assert op_dist_impl is not None, "Cannot find the dist op implementation." + dim_changed = op_dist_impl.update_dims_mapping(dist_op) + if dim_changed: + changed = True + if op_dist_impl.is_auto_compatible(dist_op): + if op_dist_impl.type == "elementwise": + op_dist_attr.impl_type = "default" + else: + op_dist_attr.impl_type = op_dist_impl.type + op_dist_attr.impl_idx = op_dist_impl.idx else: for tensor_node in op_node.outputs: if tensor_node.var() is not None: @@ -399,30 +387,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): tensor_desc.name(), compatible_dims_mapping) changed = True # Find the most compatible implemenetations from the distributed operator - op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl( - op_desc.type(), dist_op, fwd=False) - if op_dist_impl is not None: - dim_changed = op_dist_impl.update_dims_mapping(dist_op) - if dim_changed: - changed = True - # This statement will be replaced by a good way - if op_dist_impl.is_compatible(dist_op): - op_dist_attr.impl_type = op_desc.type() - op_dist_attr.impl_idx = op_dist_impl_idx - elif is_elementwise_like_op(op_desc.type()): - dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl( - dist_context, op_node) - if dim_changed: - changed = True - op_dist_attr.impl_type = "element-wise" - op_dist_attr.impl_idx = -1 - else: - dim_changed = update_op_dims_mapping_by_default_dist_impl( - dist_context, op_node) - if dim_changed: - changed = True - op_dist_attr.impl_type = "default" - op_dist_attr.impl_idx = -2 + op_dist_impl = find_best_compatible_distributed_operator_impl( + dist_op, fwd=False) + assert op_dist_impl is not None, "Cannot find the dist op implementation." 
+ dim_changed = op_dist_impl.update_dims_mapping(dist_op) + if dim_changed: + changed = True + if op_dist_impl.is_auto_compatible(dist_op): + if op_dist_impl.type == "elementwise": + op_dist_attr.impl_type = "default" + else: + op_dist_attr.impl_type = op_dist_impl.type + op_dist_attr.impl_idx = op_dist_impl.idx return changed diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index b194bcc3de6b5..2d3bfebb37640 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -61,6 +61,8 @@ def __init__(self, program=None): # Other data members self._dist_op_context = DistributedOperatorContext() self._process_meshes = [] + self._serial_ordered_nodes = [] + self._tensor_id_to_tensor_node_ids = {} # Distributed programs self._dist_main_programs = {} @@ -80,6 +82,10 @@ def serial_program(self, program): "This distributed context has already been realted to a serial program" self._serial_program = program + @property + def serial_ordered_nodes(self): + return self._serial_ordered_nodes + @property def process_meshes(self): return self._process_meshes @@ -179,6 +185,18 @@ def get_tensor_dist_attr_for_graph(self, serial_tensor_node): else: return None + def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr): + assert serial_tensor_node.is_var() and \ + serial_tensor_node.var() is not None + serial_tensor_id = serial_tensor_node.node.original_desc_id() + dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None) + assert dist_tensor is not None, \ + "The distributed tensor of the program has not been added to this context." + serial_tensor_node_id = serial_tensor_node.id() + new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor, + dist_attr) + self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor + def get_op_dist_attr_for_program(self, serial_op): serial_op_id = serial_op.desc.id() dist_op = self._dist_ops_for_program.get(serial_op_id, None) @@ -204,6 +222,35 @@ def get_op_dist_attr_for_graph(self, serial_op_node): else: return None + def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr): + assert serial_op_node.is_op() and \ + serial_op_node.op() is not None + serial_op_id = serial_op_node.node.original_desc_id() + dist_op = self._dist_ops_for_program.get(serial_op_id, None) + assert dist_op is not None, \ + "The distributed operator of the program has not been added to this context." + serial_op_node_id = serial_op_node.id() + new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr) + self._dist_ops_for_graph[serial_op_node_id] = new_dist_op + + def get_dist_attr_for_graph(self, serial_node): + if serial_node.is_var() and serial_node.var() is not None: + serial_tensor_node_id = serial_node.id() + dist_tensor = self._dist_tensors_for_graph.get( + serial_tensor_node_id, None) + if dist_tensor: + return dist_tensor.dist_attr + else: + return None + if serial_node.is_op() and serial_node.op() is not None: + serial_op_node_id = serial_node.id() + dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None) + if dist_op: + return dist_op.dist_attr + else: + return None + return None + def init_dist_attr_for_program(self): assert self._serial_program, \ "Please set the program of this context before initializing its distribute attributes." 
@@ -234,6 +281,44 @@ def init_dist_attr_for_program(self): self.add_dist_op_for_program(dist_op) self._is_initialized_for_program = True + def order_nodes_by_program_order(self): + def _contains(nodes, target_node): + for node in nodes: + if node.id() == target_node.id(): + return True + return False + + ordered_tensor_nodes = [] + ordered_op_nodes = [] + all_nodes = self._serial_graph.all_nodes() + for node in all_nodes: + if node.is_var() and node.var() is not None: + ordered_tensor_nodes.append(node) + if node.is_op() and node.op() is not None: + ordered_op_nodes.append(node) + ordered_tensor_nodes.sort(key=lambda node: node.node.original_desc_id()) + ordered_op_nodes.sort(key=lambda node: node.node.original_desc_id()) + for op_node in ordered_op_nodes: + tensor_nodes = [] + for tensor_node in op_node.inputs: + if tensor_node.is_var() \ + and tensor_node.var() is not None \ + and not _contains(self._serial_ordered_nodes, tensor_node): + tensor_nodes.append(tensor_node) + tensor_nodes.sort(key=lambda node: node.node.original_desc_id()) + self._serial_ordered_nodes.extend(tensor_nodes) + self._serial_ordered_nodes.append(op_node) + tensor_nodes = [] + for tensor_node in op_node.outputs: + if tensor_node.is_var() \ + and tensor_node.var() is not None \ + and not _contains(self._serial_ordered_nodes, tensor_node): + tensor_nodes.append(tensor_node) + self._serial_ordered_nodes.extend(tensor_nodes) + num_nodes_before = len(ordered_tensor_nodes) + len(ordered_op_nodes) + assert len(self._serial_ordered_nodes) == num_nodes_before, \ + "The number of nodes before ordering is not the same after ordering." + def init_dist_attr_for_graph(self): assert self._is_initialized_for_program, \ "The program must be initialized before initializing the distributed attributes for its graph." 
@@ -243,7 +328,8 @@ def init_dist_attr_for_graph(self): self._serial_graph = framework.IrGraph( core.Graph(self._serial_program.desc)) all_nodes = self._serial_graph.all_nodes() - for node in all_nodes: + self.order_nodes_by_program_order() + for node in self.serial_ordered_nodes: if node.is_var() and node.var() is not None: dist_tensor = None tensor_id = node.node.original_desc_id() @@ -383,7 +469,9 @@ def __deepcopy__(self, memo): result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k == "_serial_program" or k == "_serial_graph" or k == "_dist_main_programs" or k == "_dist_startup_programs": + if k == "_serial_program" or k == "_serial_graph" \ + or k == "_dist_main_programs" or k == "_dist_startup_programs" \ + or k == "_serial_ordered_nodes": setattr(result, k, v) else: setattr(result, k, copy.deepcopy(v, memo)) diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py index ef595e2a00f2e..cce047f80555f 100644 --- a/python/paddle/distributed/auto_parallel/dist_op.py +++ b/python/paddle/distributed/auto_parallel/dist_op.py @@ -98,7 +98,7 @@ def _init_default_dist_attr(self): if self._dist_attr.impl_type is None: self._dist_attr.impl_type = "default" if self._dist_attr.impl_idx is None: - self._dist_attr.impl_idx = -2 + self._dist_attr.impl_idx = 0 def _filter_dist_attr(self, dist_attr): if dist_attr is None: @@ -215,7 +215,8 @@ def __str__(self): str += ", pipeline stage: {}".format(None) - str += ", dist_impl idx: {} }}".format(self.dist_attr._impl_idx) + str += ", dist_impl idx: {} , dist_impl type {} }}".format( + self.dist_attr._impl_idx, self.dist_attr._impl_type) return str diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index c28b7930124dd..45854052dda4d 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -22,6 +22,8 @@ from . import dist_reshape from . import dist_softmax from . import dist_transpose +from . import dist_eltwise +from . import dist_split from . import dist_default from . import dist_check_finite_and_unscale from . 
import dist_update_loss_scaling diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 8f1ba33f544fb..99fbe7a65f2b0 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -12,53 +12,124 @@ # See the License for the specific language governing permissions and # limitations under the License +import abc from ..dist_attribute import OperatorDistributedAttribute -_g_distributed_operator_impl_registries = {} +_g_distributed_operator_impl_containers = {} + +_g_elementwise_ops = ["elementwise_add", "gelu", "dropout", "cast"] BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'} +def is_elementwise_op(op_type): + if op_type in _g_elementwise_ops: + return True + else: + return False + + class DistributedOperatorImplContainer: - def __init__(self): + def __init__(self, op_type): + self._type = op_type self._impls = [] - self._name = None + + @property + def type(self): + return self._type + + @type.setter + def type(self, op_type): + self._type = op_type + + @property + def impls(self): + return self._impls def register_impl(self, dist_impl): + assert self.type == dist_impl.type, \ + "Op type of container must be same as that of the implementation." + impl_idx = len(self.impls) + dist_impl.idx = impl_idx self._impls.append(dist_impl) def get_impl(self, impl_idx): return self._impls[impl_idx] - def get_impls(self): - return self._impls - + def get_input_compatible_impls(self, dist_op): + compatible_impls = [] + for impl in self.impls: + if impl.is_input_compatible(dist_op): + compatible_impls.append(impl) + return compatible_impls -class DistributedOperatorImpl: - def __init__(self): - self._name = None + def get_output_compatible_impls(self, dist_op): + compatible_impls = [] + for impl in self.impls: + if impl.is_output_compatible(dist_op): + compatible_impls.append(impl) + return compatible_impls + + def get_compatible_impls(self, dist_op): + compatible_impls = [] + for impl in self.impls: + if impl.is_auto_compatible(dist_op): + compatible_impls.append(impl) + return compatible_impls + + +class DistributedOperatorImpl(abc.ABC): + def __init__(self, name): + self._name = name + self._type = None + self._idx = None self._forward_implemented = False self._backward_implemented = False - @staticmethod - def forward(dist_ctx, *args, **kwargs): - raise NotImplementedError("Please Implement this method in Subclass.") + @property + def name(self): + return self._name - @staticmethod - def backward(dist_ctx, *grad_outputs, **kwargs): - raise NotImplementedError("Please Implement this method in Subclass.") + @name.setter + def name(self, name): + self._name = name - def get_name(self): - return self._name + @property + def type(self): + return self._type + @type.setter + def type(self, op_type): + self._type = op_type + + @property + def idx(self): + return self._idx + + @idx.setter + def idx(self, impl_idx): + self._idx = impl_idx + + @abc.abstractmethod def is_input_compatible(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") + @abc.abstractmethod def is_output_compatible(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") - def is_compatible(self, dist_op): - return self.is_input_compatible(dist_op) and \ - self.is_output_compatible(dist_op) + @abc.abstractmethod + def is_auto_compatible(self, dist_op): + raise NotImplementedError("Please 
Implement this method in Subclass.") + + @staticmethod + @abc.abstractmethod + def forward(dist_ctx, *args, **kwargs): + raise NotImplementedError("Please Implement this method in Subclass.") + + @staticmethod + @abc.abstractmethod + def backward(dist_ctx, *grad_outputs, **kwargs): + raise NotImplementedError("Please Implement this method in Subclass.") def is_auto_compatible(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") @@ -67,54 +138,73 @@ def update_dims_mapping(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") -def register_distributed_operator_impl_container(name, dist_op_impl_container): - global _g_distributed_operator_impl_registries - _g_distributed_operator_impl_registries[name] = dist_op_impl_container +def register_distributed_operator_impl_container(container): + global _g_distributed_operator_impl_containers + _g_distributed_operator_impl_containers[container.type] = container -def get_distributed_operator_impl_container(name): - global _g_distributed_operator_impl_registries - return _g_distributed_operator_impl_registries.get(name, None) +def get_distributed_operator_impl_container(op_type): + global _g_distributed_operator_impl_containers + return _g_distributed_operator_impl_containers.get(op_type, None) -def register_distributed_operator_impl(name, dist_impl): - dist_op_impl_container = get_distributed_operator_impl_container(name) +def register_distributed_operator_impl(op_type, dist_impl): + dist_op_impl_container = get_distributed_operator_impl_container(op_type) if dist_op_impl_container is not None: + dist_impl.type = op_type dist_op_impl_container.register_impl(dist_impl) else: assert False, "Must register distributed operator registry first." -def get_distributed_operator_impl(name, impl_idx): - global _g_distributed_operator_impl_registries - return _g_distributed_operator_impl_registries[name].get_impl(impl_idx) - - -def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True): +def find_best_compatible_distributed_operator_impl(dist_op, fwd=True): """ Here just return the first compatible implemention. This will be improved by cost model in the future. 
""" - dist_op_impl_container = get_distributed_operator_impl_container(name) - if dist_op_impl_container is None: - return None, -1 + op_type = dist_op.serial_op.type + dist_op_impl_container = get_distributed_operator_impl_container(op_type) + dist_op_eltwise_impl_container = get_distributed_operator_impl_container( + "elementwise") + dist_op_default_impl_container = get_distributed_operator_impl_container( + "default") compatible_impls = [] - impls = dist_op_impl_container.get_impls() if fwd: - for idx, impl in enumerate(impls): - if impl.is_input_compatible(dist_op): - compatible_impls.append((impl, idx)) + # First, find impls in the corresponding container + if dist_op_impl_container: + compatible_impls.extend( + dist_op_impl_container.get_input_compatible_impls(dist_op)) + # Second, find impls in the elementwise container + if dist_op_eltwise_impl_container and is_elementwise_op(op_type): + compatible_impls.extend( + dist_op_eltwise_impl_container.get_input_compatible_impls( + dist_op)) + # Third, find impls in the default container + if dist_op_default_impl_container: + compatible_impls.extend( + dist_op_default_impl_container.get_input_compatible_impls( + dist_op)) else: - for idx, impl in enumerate(impls): - if impl.is_output_compatible(dist_op): - compatible_impls.append((impl, idx)) - + # First, find impls in the corresponding container + if dist_op_impl_container: + compatible_impls.extend( + dist_op_impl_container.get_output_compatible_impls(dist_op)) + # Second, find impls in the elementwise container + if dist_op_eltwise_impl_container and is_elementwise_op(op_type): + compatible_impls.extend( + dist_op_eltwise_impl_container.get_output_compatible_impls( + dist_op)) + # Third, find impls in the default container + if dist_op_default_impl_container: + compatible_impls.extend( + dist_op_default_impl_container.get_output_compatible_impls( + dist_op)) if compatible_impls: - best_compatible_impl, idx = compatible_impls[0] + # For now, just return the first compatible impl + best_compatible_impl = compatible_impls[0] else: - best_compatible_impl, idx = None, -1 - - return best_compatible_impl, idx + best_compatible_impl = None + return best_compatible_impl def is_parameter_related(varname, block): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py index 00dc346f9a2ac..52d5e85c962eb 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py @@ -30,19 +30,17 @@ class DistributedCheckFiniteAndUnscale(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedCheckFiniteAndUnscale, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedCheckFiniteAndUnscale, self).__init__(op_type) register_distributed_operator_impl_container( - "check_finite_and_unscale", DistributedCheckFiniteAndUnscale("check_finite_and_unscale")) class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): def __init__(self, name): - super(DistributedCheckFiniteAndUnscaleImpl, self).__init__() + super(DistributedCheckFiniteAndUnscaleImpl, self).__init__(name) self._name = name self._forward_implemented = False self._backward_implemented = True @@ -57,6 +55,11 @@ def is_output_compatible(self, dist_op): "DistributedCheckFiniteAndUnscaleImpl's is_output_compatible should not be called !" 
) + def is_auto_compatible(self, dist_op): + raise RuntimeError( + "DistributedCheckFiniteAndUnscaleImpl's is_auto_compatible should not be called !" + ) + def update_dims_mapping(self, dist_op): raise RuntimeError( "DistributedCheckFiniteAndUnscaleImpl's update_dims_mapping should not be called !" diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index 1a3d57bf140dd..d186d3640f50c 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -34,31 +34,162 @@ class DistributedDefault(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedDefault, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedDefault, self).__init__(op_type) -register_distributed_operator_impl_container("default", - DistributedDefault("default")) +register_distributed_operator_impl_container(DistributedDefault("default")) # Replicated Default class DistributedDefaultImpl0(DistributedOperatorImpl): def __init__(self, name): - super(DistributedDefaultImpl0, self).__init__() - self._name = name + super(DistributedDefaultImpl0, self).__init__(name) self._forward_implemented = True self._backward_implemented = True def is_input_compatible(self, dist_op): - raise NotImplementedError("Please Implement this method.") + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + for arg_name in op_desc.input_arg_names(): + serial_tensor = dist_op.get_serial_input(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if len(dims_mapping) > 1: + for mapping in dims_mapping[1:]: + if mapping != -1: + return False + return True def is_output_compatible(self, dist_op): - raise NotImplementedError("Please Implement this method.") + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + output_names = op_desc.output_names() + xshape_arg_names = [] + if "XShape" in output_names: + xshape_arg_names = op_desc.output("XShape") + for arg_name in op_desc.output_arg_names(): + serial_tensor = dist_op.get_serial_output(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if arg_name not in xshape_arg_names: + if len(dims_mapping) > 1: + for mapping in dims_mapping[1:]: + if mapping != -1: + return False + else: + if dims_mapping[0] != -1: + return False + if len(dims_mapping) > 2: + for mapping in dims_mapping[2:]: + if mapping != -1: + return False + return True + + def is_auto_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + batch_dim_mappings = [] + # Check input compatibility + for arg_name in op_desc.input_arg_names(): + serial_tensor = dist_op.get_serial_input(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if len(dims_mapping) > 1: + for mapping in dims_mapping[1:]: + if mapping != -1: + return False + batch_dim_mappings.append(dims_mapping[0]) + + # Check output compatibility + output_names = op_desc.output_names() + xshape_arg_names = [] + if "XShape" in output_names: + xshape_arg_names = op_desc.output("XShape") + for arg_name in op_desc.output_arg_names(): + serial_tensor = dist_op.get_serial_output(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = 
op_dist_attr.get_output_dims_mapping(arg_name) + if arg_name not in xshape_arg_names: + if len(dims_mapping) > 1: + for mapping in dims_mapping[1:]: + if mapping != -1: + return False + batch_dim_mappings.append(dims_mapping[0]) + else: + if dims_mapping[0] != -1: + return False + if len(dims_mapping) > 2: + for mapping in dims_mapping[2:]: + if mapping != -1: + return False + batch_dim_mappings.append(dims_mapping[1]) + + # Check batch dim mapping compatibility + if not all(batch_dim_mappings[0] == dim_mapping + for dim_mapping in batch_dim_mappings): + return False + + return True def update_dims_mapping(self, dist_op): - raise NotImplementedError("Please Implement this method.") + changed = False + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + # The following statement will be replaced by a more elegent way + if op_desc.type() == "shape" or op_desc.type() == "slice": + return False + output_names = op_desc.output_names() + xshape_arg_names = [] + if "XShape" in output_names: + xshape_arg_names = op_desc.output("XShape") + batch_dim_mappings = [] + for arg_name in op_desc.input_arg_names(): + serial_tensor = dist_op.get_serial_input(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + batch_dim_mappings.append(dims_mapping[0]) + for arg_name in op_desc.output_arg_names(): + serial_tensor = dist_op.get_serial_output(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if arg_name not in xshape_arg_names: + batch_dim_mappings.append(dims_mapping[0]) + else: + batch_dim_mappings.append(dims_mapping[1]) + + compatible_dim_mapping = compute_compatible_dim_mapping( + batch_dim_mappings) + assert compatible_dim_mapping is not None, "There is no compatible dim mapping." 
+ for arg_name in op_desc.input_arg_names(): + serial_tensor = dist_op.get_serial_input(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if compatible_dim_mapping != dims_mapping[0]: + dims_mapping[0] = compatible_dim_mapping + changed = True + for arg_name in op_desc.output_arg_names(): + serial_tensor = dist_op.get_serial_output(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if arg_name not in xshape_arg_names: + if compatible_dim_mapping != dims_mapping[0]: + dims_mapping[0] = compatible_dim_mapping + changed = True + else: + if compatible_dim_mapping != dims_mapping[1]: + dims_mapping[1] = compatible_dim_mapping + changed = True + + return changed @staticmethod def forward(ctx, *args, **kwargs): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 866fed1ae6067..1d3c047cac2e8 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -34,22 +34,20 @@ class DistributedEmbedding(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedEmbedding, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedEmbedding, self).__init__(op_type) -register_distributed_operator_impl_container("lookup_table_v2", - DistributedEmbedding("embedding")) -register_distributed_operator_impl_container("c_embedding", - DistributedEmbedding("embedding")) +register_distributed_operator_impl_container( + DistributedEmbedding("lookup_table_v2")) +register_distributed_operator_impl_container( + DistributedEmbedding("c_embedding")) # RowParallel class DistributedEmbeddingImpl(DistributedOperatorImpl): def __init__(self, name): - super(DistributedEmbeddingImpl, self).__init__() - self._name = name + super(DistributedEmbeddingImpl, self).__init__(name) self._forward_implemented = True self._backward_implemented = True @@ -81,6 +79,10 @@ def is_output_compatible(self, dist_op): return True def is_auto_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr ids_name = op_desc.input('Ids')[0] @@ -89,18 +91,7 @@ def is_auto_compatible(self, dist_op): out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name) w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name) - if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(w_dims_mapping[ - -1]): - return False - # Other dimensions must be replicate except the batch dimension - for mapping in ids_dims_mapping[1:]: - if is_dim_shard(mapping): - return False - for mapping in out_dims_mapping[1:]: - if is_dim_shard(mapping): - return False - if w_dims_mapping[-1] != out_dims_mapping[-1]: - return False + if ids_dims_mapping != out_dims_mapping[:len(ids_dims_mapping)]: return False @@ -248,6 +239,7 @@ def forward(ctx, *args, **kwargs): # matmulv2 embedding_op_dist_attr = OperatorDistributedAttribute() embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh + embedding_op_dist_attr.impl_type = op_dist_attr.impl_type embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in c_embedding_op.desc.input_arg_names(): input_dist_attr = 
op_dist_attr.get_input_dist_attr(input_varname) @@ -266,6 +258,7 @@ def forward(ctx, *args, **kwargs): # allreduce allreduce_op_dist_attr = OperatorDistributedAttribute() allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh + allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in c_allreduce_sum_op.desc.input_arg_names(): input_var = main_block.var(input_varname) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index f4c31c3654c52..e365fcf52dff2 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -34,6 +34,7 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from ..process_group import new_process_group from ..utils import _get_comm_group, _get_corresponding_rank +from .dist_default import DistributedDefaultImpl0 def copy_op_with_new_input_output(ctx, block, src_op, **kwargs): @@ -143,6 +144,68 @@ def _update_dims_mapping_for_matmul(dist_op): return changed +def _is_auto_compatible_for_matmul(dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + x_name = op_desc.input('X')[0] + y_name = op_desc.input('Y')[0] + out_name = op_desc.output('Out')[0] + # Deep copy these dims_mappings for keeping them unchanged. + x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name)) + y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name)) + out_dims_mapping = copy.deepcopy( + op_dist_attr.get_output_dims_mapping(out_name)) + x_dims_mapping_len = len(x_dims_mapping) + y_dims_mapping_len = len(y_dims_mapping) + out_dims_mapping_len = len(out_dims_mapping) + + # Add dim mapping to Make sure the length dims_mapping be at least 2 + if x_dims_mapping_len == 1: + x_dims_mapping.insert(0, -1) + if y_dims_mapping_len == 1: + y_dims_mapping.insert(1, -1) + + # Deal with dim > 2 and take care of broadcasting + if out_dims_mapping_len > 2: + broadcast_x_dims_mapping = [] + broadcast_y_dims_mapping = [] + broadcast_out_dims_mapping = [] + + for i in range(out_dims_mapping_len - x_dims_mapping_len): + broadcast_x_dims_mapping.append(out_dims_mapping[i]) + for i in range(x_dims_mapping_len - 2): + broadcast_x_dims_mapping.append(x_dims_mapping[i]) + + for i in range(out_dims_mapping_len - y_dims_mapping_len): + broadcast_y_dims_mapping.append(out_dims_mapping[i]) + for i in range(y_dims_mapping_len - 2): + broadcast_y_dims_mapping.append(y_dims_mapping[i]) + + for i in range(out_dims_mapping_len - 2): + broadcast_out_dims_mapping.append(out_dims_mapping[i]) + + is_same = ((broadcast_x_dims_mapping == broadcast_y_dims_mapping) and + (broadcast_x_dims_mapping == broadcast_out_dims_mapping)) + if not is_same: + return False + + # The following which uses negative index can be work + # when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2 + is_same = (x_dims_mapping[-1] == y_dims_mapping[-2]) + if not is_same: + return False + + is_same = (x_dims_mapping[-2] == out_dims_mapping[-2]) + if not is_same: + return False + + is_same = (y_dims_mapping[-1] == out_dims_mapping[-1]) + if not is_same: + return False + + return True + + def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): # by now the backward function only insert the gradient allreduce for dist op itself @@ -194,10 +257,10 @@ def 
_right_operand_parameter_matmul_backward(ctx, *args, **kwargs): Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name) process_mesh_shape = dist_attr.process_mesh.topology process_mesh_group = dist_attr.process_mesh.processes - assert len( - Y_var_dim_mapping - ) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format( - Y_var.name, Y_var_dim_mapping) + # assert len( + # Y_var_dim_mapping + # ) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format( + # Y_var.name, Y_var_dim_mapping) Y_var_partitioned = False for dim in Y_var_dim_mapping: if dim >= 0 and process_mesh_shape[dim] > 0: @@ -388,20 +451,17 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id): class DistributedMatmul(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedMatmul, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedMatmul, self).__init__(op_type) -register_distributed_operator_impl_container("matmul", - DistributedMatmul("matmul")) +register_distributed_operator_impl_container(DistributedMatmul("matmul")) # ColumnParallel class DistributedMatmulImpl0(DistributedOperatorImpl): def __init__(self, name): - super(DistributedMatmulImpl0, self).__init__() - self._name = name + super(DistributedMatmulImpl0, self).__init__(name) self._forward_implemented = True self._backward_implemented = True @@ -414,8 +474,8 @@ def is_input_compatible(self, dist_op): y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) if is_dim_shard(x_dims_mapping[-1]): return False - if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[ - 1]): + if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(y_dims_mapping[ + -1]): return False for mapping in x_dims_mapping[1:-1]: if is_dim_shard(mapping): @@ -435,83 +495,11 @@ def is_output_compatible(self, dist_op): return True def is_auto_compatible(self, dist_op): - op_desc = dist_op.serial_op.desc - op_dist_attr = dist_op.dist_attr - x_name = op_desc.input('X')[0] - y_name = op_desc.input('Y')[0] - out_name = op_desc.output('Out')[0] - out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) - - assert len(x_dims_mapping) >= len( - y_dims_mapping), "now just support x dims > y dims" - if len(y_dims_mapping) != 2: + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): return False - - if len(x_dims_mapping) == len(y_dims_mapping) and len( - x_dims_mapping) == 4: - if x_dims_mapping[:2] != y_dims_mapping[:2]: - return False - if x_dims_mapping[:2] != out_dims_mapping[:2]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - elif len(x_dims_mapping) != len(y_dims_mapping) and len( - x_dims_mapping) == 3: - if x_dims_mapping[0] != out_dims_mapping[0]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - - if is_dim_replicate(out_dims_mapping[-1]): + if not _is_auto_compatible_for_matmul(dist_op): return False - - for mapping in out_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - input_dims_mapping = [] - ordered_input_shard_dims_mapping = [] - - for dim in (x_dims_mapping + y_dims_mapping): - input_dims_mapping.append(dim) - - for item in 
input_dims_mapping: - if item not in ordered_input_shard_dims_mapping and item != -1: - ordered_input_shard_dims_mapping.append(item) - - for mapping in out_dims_mapping: - if mapping not in input_dims_mapping: - return False - - if is_dim_shard(x_dims_mapping[0]): - order_index = 0 - for idx, item in enumerate(out_dims_mapping): - if item != -1: - if item != ordered_input_shard_dims_mapping[order_index]: - return False - else: - order_index += 1 - if order_index != len(ordered_input_shard_dims_mapping): - return False - - if is_dim_shard(x_dims_mapping[-1]): - return False - if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[ - 1]): - return False - for mapping in x_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - if is_dim_shard(x_dims_mapping[0]): - for mapping in y_dims_mapping[1:]: - if is_dim_shard(mapping) and mapping == x_dims_mapping[0]: - return False - return True def update_dims_mapping(self, dist_op): @@ -635,6 +623,7 @@ def forward(ctx, *args, **kwargs): # c_identity identity_op_dist_attr = OperatorDistributedAttribute() identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh + identity_op_dist_attr.impl_type = op_dist_attr.impl_type identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx # input input_varname = c_identity_op.desc.input_arg_names()[0] @@ -653,6 +642,7 @@ def forward(ctx, *args, **kwargs): # matmul matmul_op_dist_attr = OperatorDistributedAttribute() matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh + matmul_op_dist_attr.impl_type = op_dist_attr.impl_type matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx # input for input_varname in matmul_op.desc.input_arg_names(): @@ -692,8 +682,7 @@ def backward(ctx, *args, **kwargs): # RowParallel class DistributedMatmulImpl1(DistributedOperatorImpl): def __init__(self, name): - super(DistributedMatmulImpl1, self).__init__() - self._name = name + super(DistributedMatmulImpl1, self).__init__(name) self._forward_implemented = True self._backward_implemented = True @@ -729,93 +718,12 @@ def is_output_compatible(self, dist_op): return True def is_auto_compatible(self, dist_op): - op_desc = dist_op.serial_op.desc - op_dist_attr = dist_op.dist_attr - x_name = op_desc.input('X')[0] - y_name = op_desc.input('Y')[0] - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) - - if op_desc.attr('transpose_X') or op_desc.attr('transpose_Y'): - return False - out_name = op_desc.output('Out')[0] - out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - # for gpt2, x dims > y dims, this is a temporary solution - assert len(x_dims_mapping) >= len( - y_dims_mapping), "now just support x dims > y dims" - if len(y_dims_mapping) != 2: + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): return False - if len(x_dims_mapping) == len(y_dims_mapping) and len( - x_dims_mapping) == 4: - if x_dims_mapping[:2] != y_dims_mapping[:2]: - return False - if x_dims_mapping[:2] != out_dims_mapping[:2]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - elif len(x_dims_mapping) != len(y_dims_mapping) and len( - x_dims_mapping) == 3: - if x_dims_mapping[0] != out_dims_mapping[0]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - if is_dim_shard(out_dims_mapping[-1]): + if not 
_is_auto_compatible_for_matmul(dist_op): return False - # Other dimensions must be replicate except the batch dimension - for mapping in out_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - if is_dim_replicate(x_dims_mapping[-1]): - return False - - if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[ - -1]): - return False - - # Other dimensions must be replicate except the batch dimension - for mapping in x_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - x_shard_dim_count = 0 - x_shard_dims = [] - y_shard_dim_count = 0 - y_shard_dims = [] - for dim in x_dims_mapping: - if is_dim_shard(dim): - x_shard_dim_count += 1 - x_shard_dims.append(dim) - - for dim in y_dims_mapping: - if is_dim_shard(dim): - y_shard_dim_count += 1 - y_shard_dims.append(dim) - - if not x_shard_dims and not y_shard_dims: - return False - - if x_shard_dims[-1] != y_shard_dims[0]: - return False - - if x_shard_dim_count == y_shard_dim_count: - for dim in out_dims_mapping: - if is_dim_shard(dim): - return False - if x_shard_dims != y_shard_dims: - return False - else: - if x_shard_dim_count < y_shard_dim_count: - return False - output_shard_dims = [] - for dim in out_dims_mapping: - if is_dim_shard(dim): - output_shard_dims.append(dim) - if not output_shard_dims or output_shard_dims[0] != x_shard_dims[0]: - return False return True @@ -933,6 +841,7 @@ def forward(ctx, *args, **kwargs): # matmul matmul_op_dist_attr = OperatorDistributedAttribute() matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh + matmul_op_dist_attr.impl_type = op_dist_attr.impl_type matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in matmul_op.desc.input_arg_names(): input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) @@ -951,6 +860,7 @@ def forward(ctx, *args, **kwargs): # allreduce allreduce_op_dist_attr = OperatorDistributedAttribute() allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh + allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in c_allreduce_sum_op.desc.input_arg_names(): input_var = main_block.var(input_varname) @@ -980,8 +890,7 @@ def backward(ctx, *args, **kwargs): # ReplicateParallel class DistributedMatmulImpl2(DistributedOperatorImpl): def __init__(self, name): - super(DistributedMatmulImpl2, self).__init__() - self._name = name + super(DistributedMatmulImpl2, self).__init__(name) def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc @@ -1020,56 +929,11 @@ def is_output_compatible(self, dist_op): return True def is_auto_compatible(self, dist_op): - op_desc = dist_op.serial_op.desc - op_dist_attr = dist_op.dist_attr - x_name = op_desc.input('X')[0] - y_name = op_desc.input('Y')[0] - out_name = op_desc.output('Out')[0] - out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) - assert len(x_dims_mapping) >= len( - y_dims_mapping - ), "now just support x dims > y dims,but x:{0} and y:{1}".format( - x_dims_mapping, y_dims_mapping) - if len(y_dims_mapping) != 2: - return False - if len(x_dims_mapping) == len(y_dims_mapping) and len( - x_dims_mapping) == 4: - if x_dims_mapping[:2] != y_dims_mapping[:2]: - return False - if x_dims_mapping[:2] != out_dims_mapping[:2]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = 
out_dims_mapping[-2:] - elif len(x_dims_mapping) != len(y_dims_mapping) and len( - x_dims_mapping) == 3: - if x_dims_mapping[0] != out_dims_mapping[0]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - - if is_dim_shard(out_dims_mapping[-1]): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): return False - if is_valid_list_index(out_dims_mapping, - -2) and is_dim_shard(out_dims_mapping[-2]): - return False - - if is_dim_shard(x_dims_mapping[-1]): - return False - - if is_valid_list_index(x_dims_mapping, - -2) and is_dim_shard(x_dims_mapping[-2]): - return False - - if is_dim_shard(y_dims_mapping[-1]): - return False - - if is_valid_list_index(y_dims_mapping, - -2) and is_dim_shard(y_dims_mapping[-2]): + if not _is_auto_compatible_for_matmul(dist_op): return False return True @@ -1081,6 +945,10 @@ def update_dims_mapping(self, dist_op): changed = True return changed + @staticmethod + def forward(ctx, *args, **kwargs): + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + @staticmethod def backward(ctx, *args, **kwargs): _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) @@ -1095,20 +963,17 @@ def backward(ctx, *args, **kwargs): class DistributedMatmulV2(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedMatmulV2, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedMatmulV2, self).__init__(op_type) -register_distributed_operator_impl_container("matmul_v2", - DistributedMatmulV2("matmul_v2")) +register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2")) # ColumnParallel class DistributedMatmulV2Impl0(DistributedOperatorImpl): def __init__(self, name): - super(DistributedMatmulV2Impl0, self).__init__() - self._name = name + super(DistributedMatmulV2Impl0, self).__init__(name) self._forward_implemented = True self._backward_implemented = True @@ -1121,8 +986,8 @@ def is_input_compatible(self, dist_op): y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) if is_dim_shard(x_dims_mapping[-1]): return False - if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[ - 1]): + if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(y_dims_mapping[ + -1]): return False for mapping in x_dims_mapping[1:-1]: if is_dim_shard(mapping): @@ -1142,85 +1007,13 @@ def is_output_compatible(self, dist_op): return True def is_auto_compatible(self, dist_op): - op_desc = dist_op.serial_op.desc - op_dist_attr = dist_op.dist_attr - x_name = op_desc.input('X')[0] - y_name = op_desc.input('Y')[0] - out_name = op_desc.output('Out')[0] - out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) - - if op_desc.attr('trans_x') or op_desc.attr('trans_y'): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): return False - assert len(x_dims_mapping) >= len( - y_dims_mapping), "now just support x dims > y dims" - if len(y_dims_mapping) != 2: - return False - if len(x_dims_mapping) == len(y_dims_mapping) and len( - x_dims_mapping) == 4: - if x_dims_mapping[:2] != y_dims_mapping[:2]: - return False - if x_dims_mapping[:2] != out_dims_mapping[:2]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - elif len(x_dims_mapping) != 
len(y_dims_mapping) and len( - x_dims_mapping) == 3: - if x_dims_mapping[0] != out_dims_mapping[0]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - if is_dim_replicate(out_dims_mapping[-1]): + if not _is_auto_compatible_for_matmul(dist_op): return False - for mapping in out_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - input_dims_mapping = [] - ordered_input_shard_dims_mapping = [] - - for dim in (x_dims_mapping + y_dims_mapping): - input_dims_mapping.append(dim) - - for item in input_dims_mapping: - if item not in ordered_input_shard_dims_mapping and item != -1: - ordered_input_shard_dims_mapping.append(item) - - for mapping in out_dims_mapping: - if mapping not in input_dims_mapping: - return False - - if is_dim_shard(x_dims_mapping[0]): - order_index = 0 - for idx, item in enumerate(out_dims_mapping): - if item != -1: - if item != ordered_input_shard_dims_mapping[order_index]: - return False - else: - order_index += 1 - if order_index != len(ordered_input_shard_dims_mapping): - return False - - if is_dim_shard(x_dims_mapping[-1]): - return False - - if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[ - 1]): - return False - - for mapping in x_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - if is_dim_shard(x_dims_mapping[0]): - for mapping in y_dims_mapping[1:]: - if is_dim_shard(mapping) and mapping == x_dims_mapping[0]: - return False - return True def update_dims_mapping(self, dist_op): @@ -1342,6 +1135,7 @@ def forward(ctx, *args, **kwargs): # c_identity identity_op_dist_attr = OperatorDistributedAttribute() identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh + identity_op_dist_attr.impl_type = op_dist_attr.impl_type identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx # input input_varname = c_identity_op.desc.input_arg_names()[0] @@ -1359,6 +1153,7 @@ def forward(ctx, *args, **kwargs): # matmulv2 matmulv2_op_dist_attr = OperatorDistributedAttribute() matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh + matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in matmul_v2_op.desc.input_arg_names(): if input_varname in src_op.desc.input_arg_names(): @@ -1395,8 +1190,7 @@ def backward(ctx, *args, **kwargs): # RowParallel class DistributedMatmulV2Impl1(DistributedOperatorImpl): def __init__(self, name): - super(DistributedMatmulV2Impl1, self).__init__() - self._name = name + super(DistributedMatmulV2Impl1, self).__init__(name) self._forward_implemented = True self._backward_implemented = True @@ -1432,93 +1226,13 @@ def is_output_compatible(self, dist_op): return True def is_auto_compatible(self, dist_op): - op_desc = dist_op.serial_op.desc - op_dist_attr = dist_op.dist_attr - x_name = op_desc.input('X')[0] - y_name = op_desc.input('Y')[0] - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) - if op_desc.attr('trans_x') or op_desc.attr('trans_y'): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): return False - out_name = op_desc.output('Out')[0] - out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - assert len(x_dims_mapping) >= len( - y_dims_mapping), "now just support x dims > y dims" - if len(y_dims_mapping) != 2: - return False - if len(x_dims_mapping) == len(y_dims_mapping) and len( - x_dims_mapping) == 4: - 
if x_dims_mapping[:2] != y_dims_mapping[:2]: - return False - if x_dims_mapping[:2] != out_dims_mapping[:2]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - - elif len(x_dims_mapping) != len(y_dims_mapping) and len( - x_dims_mapping) == 3: - if x_dims_mapping[0] != out_dims_mapping[0]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - if is_dim_shard(out_dims_mapping[-1]): + if not _is_auto_compatible_for_matmul(dist_op): return False - # Other dimensions must be replicate except the batch dimension - for mapping in out_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - if is_dim_replicate(x_dims_mapping[-1]): - return False - - if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[ - -1]): - return False - - # Other dimensions must be replicate except the batch dimension - for mapping in x_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - x_shard_dim_count = 0 - x_shard_dims = [] - y_shard_dim_count = 0 - y_shard_dims = [] - for dim in x_dims_mapping: - if is_dim_shard(dim): - x_shard_dim_count += 1 - x_shard_dims.append(dim) - - for dim in y_dims_mapping: - if is_dim_shard(dim): - y_shard_dim_count += 1 - y_shard_dims.append(dim) - - if not x_shard_dims and not y_shard_dims: - return False - - if x_shard_dims[-1] != y_shard_dims[0]: - return False - - if x_shard_dim_count == y_shard_dim_count: - for dim in out_dims_mapping: - if is_dim_shard(dim): - return False - if x_shard_dims != y_shard_dims: - return False - else: - if x_shard_dim_count < y_shard_dim_count: - return False - output_shard_dims = [] - for dim in out_dims_mapping: - if is_dim_shard(dim): - output_shard_dims.append(dim) - if not output_shard_dims or output_shard_dims[0] != x_shard_dims[0]: - return False return True def update_dims_mapping(self, dist_op): @@ -1631,6 +1345,7 @@ def forward(ctx, *args, **kwargs): # matmulv2 matmulv2_op_dist_attr = OperatorDistributedAttribute() matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh + matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in matmul_v2_op.desc.input_arg_names(): input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) @@ -1649,6 +1364,7 @@ def forward(ctx, *args, **kwargs): # allreduce allreduce_op_dist_attr = OperatorDistributedAttribute() allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh + allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in c_allreduce_sum_op.desc.input_arg_names(): input_var = main_block.var(input_varname) @@ -1678,8 +1394,7 @@ def backward(ctx, *args, **kwargs): # ReplicateParallel class DistributedMatmulV2Impl2(DistributedOperatorImpl): def __init__(self, name): - super(DistributedMatmulV2Impl2, self).__init__() - self._name = name + super(DistributedMatmulV2Impl2, self).__init__(name) def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc @@ -1720,57 +1435,11 @@ def is_output_compatible(self, dist_op): return True def is_auto_compatible(self, dist_op): - op_desc = dist_op.serial_op.desc - op_dist_attr = dist_op.dist_attr - x_name = op_desc.input('X')[0] - y_name = op_desc.input('Y')[0] - out_name = op_desc.output('Out')[0] - out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - x_dims_mapping = 
op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) - assert len(x_dims_mapping) >= len( - y_dims_mapping - ), "now just support x dims > y dims,but x:{0} and y:{1}".format( - x_dims_mapping, y_dims_mapping) - if len(y_dims_mapping) != 2: - return False - if len(x_dims_mapping) == len(y_dims_mapping) and len( - x_dims_mapping) == 4: - if x_dims_mapping[:2] != y_dims_mapping[:2]: - return False - if x_dims_mapping[:2] != out_dims_mapping[:2]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - - elif len(x_dims_mapping) != len(y_dims_mapping) and len( - x_dims_mapping) == 3: - if x_dims_mapping[0] != out_dims_mapping[0]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - - if is_dim_shard(out_dims_mapping[-1]): - return False - - if is_valid_list_index(out_dims_mapping, - -2) and is_dim_shard(out_dims_mapping[-2]): - return False - - if is_dim_shard(x_dims_mapping[-1]): - return False - - if is_valid_list_index(x_dims_mapping, - -2) and is_dim_shard(x_dims_mapping[-2]): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): return False - if is_dim_shard(y_dims_mapping[-1]): - return False - - if is_valid_list_index(y_dims_mapping, - -2) and is_dim_shard(y_dims_mapping[-2]): + if not _is_auto_compatible_for_matmul(dist_op): return False return True @@ -1782,6 +1451,10 @@ def update_dims_mapping(self, dist_op): changed = True return changed + @staticmethod + def forward(ctx, *args, **kwargs): + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + @staticmethod def backward(ctx, *args, **kwargs): _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py index e287bd75b3589..93b0d91b7836d 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py @@ -27,22 +27,20 @@ from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from .dist_default import DistributedDefaultImpl0 class DistributedReshape2(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedReshape2, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedReshape2, self).__init__(op_type) -register_distributed_operator_impl_container("reshape2", - DistributedReshape2("reshape2")) +register_distributed_operator_impl_container(DistributedReshape2("reshape2")) class DistributedReshapeImpl0(DistributedOperatorImpl): def __init__(self, name): - super(DistributedReshapeImpl0, self).__init__() - self._name = name + super(DistributedReshapeImpl0, self).__init__(name) self._forward_implemented = True self._backward_implemented = False @@ -76,6 +74,10 @@ def is_output_compatible(self, dist_op): return True def is_auto_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] @@ -85,17 +87,10 @@ def is_auto_compatible(self, dist_op): x_shape_name) x_dims_mapping = 
op_dist_attr.get_input_dims_mapping(x_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - if len(x_dims_mapping) != len(out_dims_mapping) - 1: - return False - if is_dim_shard(out_dims_mapping[-1]): - return False - - for idx, item in enumerate(out_dims_mapping[:-2]): - if x_dims_mapping[idx] != item: + for idx, dim_mapping in enumerate(out_dims_mapping[:-1]): + if x_dims_mapping[idx] != dim_mapping: return False - if out_dims_mapping[-2] != x_dims_mapping[-1]: - return False if x_shape_dims_mapping[0] != -1: return False @@ -194,13 +189,12 @@ def forward(ctx, *args, **kwargs): @staticmethod def backward(ctx, *args, **kwargs): - pass + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) class DistributedReshapeImpl1(DistributedOperatorImpl): def __init__(self, name): - super(DistributedReshapeImpl1, self).__init__() - self._name = name + super(DistributedReshapeImpl1, self).__init__(name) self._forward_implemented = True self._backward_implemented = False @@ -234,6 +228,10 @@ def is_output_compatible(self, dist_op): return True def is_auto_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] @@ -244,24 +242,13 @@ def is_auto_compatible(self, dist_op): x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping( x_shape_name) - if len(x_dims_mapping) == len(out_dims_mapping) + 2: - if out_dims_mapping[0] != x_dims_mapping[0]: - return False - if x_dims_mapping[-1] != -1 or x_dims_mapping[-2] != -1: - return False - elif len(x_dims_mapping) != len(out_dims_mapping) + 1: - return False - if is_dim_shard(x_dims_mapping[-1]): return False - for idx, item in enumerate(x_dims_mapping[:-2]): + for idx, item in enumerate(x_dims_mapping[:-1]): if out_dims_mapping[idx] != item: return False - if x_dims_mapping[-2] != out_dims_mapping[-1]: - return False - if x_shape_dims_mapping[0] != -1: return False @@ -359,7 +346,7 @@ def forward(ctx, *args, **kwargs): @staticmethod def backward(ctx, *args, **kwargs): - pass + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) register_distributed_operator_impl("reshape2", diff --git a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py index e4624b51222ed..f78f1c58dbf07 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py @@ -22,22 +22,20 @@ from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping +from .dist_default import DistributedDefaultImpl0 class DistributedSoftmax(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedSoftmax, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedSoftmax, self).__init__(op_type) -register_distributed_operator_impl_container("softmax", - DistributedSoftmax("softmax")) +register_distributed_operator_impl_container(DistributedSoftmax("softmax")) class DistributedSoftmaxImpl(DistributedOperatorImpl): def __init__(self, name): - super(DistributedSoftmaxImpl, self).__init__() - self._name = name + super(DistributedSoftmaxImpl, self).__init__(name) self._forward_implemented = False self._backward_implemented = False @@ -48,8 +46,8 @@ def is_input_compatible(self, dist_op): axis = 
op_desc.attr('axis') x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - if axis != -1 and axis != len(x_dims_mapping) - 1: - return False + # if axis != -1 and axis != len(x_dims_mapping) - 1: + # return False if is_dim_shard(x_dims_mapping[axis]): return False @@ -63,8 +61,8 @@ def is_output_compatible(self, dist_op): axis = op_desc.attr('axis') out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - if axis != -1 and axis != len(out_dims_mapping) - 1: - return False + # if axis != -1 and axis != len(out_dims_mapping) - 1: + # return False if is_dim_shard(out_dims_mapping[axis]): return False @@ -72,6 +70,10 @@ def is_output_compatible(self, dist_op): return True def is_auto_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] @@ -79,11 +81,8 @@ def is_auto_compatible(self, dist_op): out_name = op_desc.output('Out')[0] x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - if axis != -1 and axis != len(x_dims_mapping) - 1: - return False - - if is_dim_shard(x_dims_mapping[axis]): - return False + # if axis != -1 and axis != len(x_dims_mapping) - 1: + # return False if x_dims_mapping != out_dims_mapping: return False @@ -107,9 +106,13 @@ def update_dims_mapping(self, dist_op): return changed + @staticmethod + def forward(ctx, *args, **kwargs): + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + @staticmethod def backward(ctx, *args, **kwargs): - pass + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) register_distributed_operator_impl( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py index 8b40524e47315..e6a96fb795ef8 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py @@ -22,22 +22,21 @@ from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping +from .dist_default import DistributedDefaultImpl0 class DistributedTranspose2(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedTranspose2, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedTranspose2, self).__init__(op_type) register_distributed_operator_impl_container( - "transpose2", DistributedTranspose2("transpose2")) + DistributedTranspose2("transpose2")) class DistributedTranspose2Impl(DistributedOperatorImpl): def __init__(self, name): - super(DistributedTranspose2Impl, self).__init__() - self._name = name + super(DistributedTranspose2Impl, self).__init__(name) self._forward_implemented = False self._backward_implemented = False @@ -48,6 +47,10 @@ def is_output_compatible(self, dist_op): return True def is_auto_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr perm = op_desc.attr('axis') @@ -111,9 +114,13 @@ def update_dims_mapping(self, dist_op): return changed + @staticmethod + def forward(ctx, *args, **kwargs): + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + @staticmethod def backward(ctx, *args, **kwargs): - pass + 
DistributedDefaultImpl0.backward(ctx, *args, **kwargs) register_distributed_operator_impl( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py index 56782bec0856a..b46a1bdcc8791 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py @@ -20,18 +20,17 @@ class DistributedUpdateLossScaling(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedUpdateLossScaling, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedUpdateLossScaling, self).__init__(op_type) register_distributed_operator_impl_container( - "update_loss_scaling", DistributedUpdateLossScaling("update_loss_scaling")) + DistributedUpdateLossScaling("update_loss_scaling")) class DistributedUpdateLossScalingImpl(DistributedOperatorImpl): def __init__(self, name): - super(DistributedUpdateLossScalingImpl, self).__init__() + super(DistributedUpdateLossScalingImpl, self).__init__(name) self._name = name self._forward_implemented = False self._backward_implemented = True @@ -46,6 +45,11 @@ def is_output_compatible(self, dist_op): "DistributedUpdateLossScalingImpl's is_output_compatible should not be called !" ) + def is_auto_compatible(self, dist_op): + raise RuntimeError( + "DistributedUpdateLossScalingImpl's is_auto_compatible should not be called !" + ) + def update_dims_mapping(self, dist_op): raise RuntimeError( "DistributedUpdateLossScalingImpl's update_dims_mapping should not be called !" diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index 76a9faa1c8398..5a934cb61fba5 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -63,7 +63,6 @@ def __init__(self, dist_context, rank_id=0): def partition(self, serial_main_program, serial_startup_program, params_grads): - if not isinstance(serial_main_program, (Program)): raise TypeError( "main_program be paddle.fluid.framework.program, got %s here" % @@ -87,7 +86,7 @@ def partition(self, serial_main_program, serial_startup_program, serial_main_program, serial_startup_program) dist_op_context.set_dst_startup_program(partitioned_startup_prog) - # partition main program + # partition main program partitioned_main_prog, partitioned_params_grads = self.partition_main_program( serial_main_program, params_grads) @@ -281,7 +280,7 @@ def _get_dist_shape(var, dist_attr): def _partition_parameter(dist_context, src_var, dst_block, dst_varname, dst_shape): # NOTE hack to copied Parameter - # not initialized parameter, need to initialize it + # not initialized parameter, need to initialize it copied_kwargs = {} copied_kwargs['trainable'] = src_var.trainable copied_kwargs['optimize_attr'] = src_var.optimize_attr @@ -370,32 +369,25 @@ def _get_dist_op_backward_implement(backward_op, dist_context, forward_op = forward_op_id2forward_op[forward_op_id] forward_op_dist_attr = dist_context.get_op_dist_attr_for_program( forward_op) - dist_op = get_distributed_operator_impl_container(forward_op.type) - - # TODO backward should have its own impl_idx - if dist_op and forward_op_dist_attr.impl_idx >= 0 and dist_op.get_impl( \ - forward_op_dist_attr.impl_idx)._backward_implemented: - return dist_op.get_impl(forward_op_dist_attr.impl_idx) + dist_op_impl_container = 
get_distributed_operator_impl_container( + forward_op_dist_attr.impl_type) + dist_op_impl = dist_op_impl_container.get_impl( + forward_op_dist_attr.impl_idx) + return dist_op_impl - # NOTE trick for dist ops that only have backward implement + # # NOTE trick for dist ops that only have backward implement if backward_op.type in BACKWARD_ONLY_DIST_OPS: op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op) assert op_dist_attr.impl_idx >= 0 - return get_distributed_operator_impl_container( + dist_op_impl = get_distributed_operator_impl_container( backward_op.type).get_impl(op_dist_attr.impl_idx) - dist_op = get_distributed_operator_impl_container("default") return dist_op.get_impl(0) def _get_dist_op_forward_implement(forward_op, dist_context): dist_attr = dist_context.get_op_dist_attr_for_program(forward_op) - dist_op = get_distributed_operator_impl_container(forward_op.type) - - if dist_op and dist_attr.impl_idx >= 0 and dist_op.get_impl( - dist_attr.impl_idx)._forward_implemented: - return dist_op.get_impl(dist_attr.impl_idx) - - else: - dist_op = get_distributed_operator_impl_container("default") - return dist_op.get_impl(0) + dist_op_impl_container = get_distributed_operator_impl_container( + dist_attr.impl_type) + dist_op_impl = dist_op_impl_container.get_impl(dist_attr.impl_idx) + return dist_op_impl diff --git a/python/paddle/distributed/auto_parallel/planner.py b/python/paddle/distributed/auto_parallel/planner.py index 1dfefb41c80a3..5344af2a6698b 100755 --- a/python/paddle/distributed/auto_parallel/planner.py +++ b/python/paddle/distributed/auto_parallel/planner.py @@ -28,7 +28,7 @@ from .dist_op import DistributedOperator from .process_group import _g_process_group_map from .process_group import ProcessGroup, get_process_group -from .completion import is_elementwise_like_op +from .operators.common import is_elementwise_op from .operators.common import get_distributed_operator_impl_container from .utils import update_op_dims_mapping_by_default_dist_impl from .utils import update_op_dims_mapping_by_elementwise_like_dist_impl @@ -216,7 +216,7 @@ def _enum_valid_dist_attr_for_op(program, op, process_mesh): # compose dims mapping composed_dims_mapping_list = list( product( - *[dims_mapping_dict[key] for key in dims_mapping_dict.keys()])) + * [dims_mapping_dict[key] for key in dims_mapping_dict.keys()])) for composed_dims_mapping in composed_dims_mapping_list: op_dist_attr = OperatorDistributedAttribute() op_dist_attr.process_mesh = process_mesh @@ -237,7 +237,7 @@ def _enum_valid_dist_attr_for_op(program, op, process_mesh): dist_op = DistributedOperator(op, op_dist_attr) if dist_op_impl_container is None: - if is_elementwise_like_op(op.type): + if is_elementwise_op(op.type): changed = True valid = True try: @@ -271,7 +271,7 @@ def _enum_valid_dist_attr_for_op(program, op, process_mesh): continue # if op has distributed implements, find all valid dist attr of this op - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls for idx, impl in enumerate(impls): if impl.is_auto_compatible(dist_op): if PlanFilter.check_dims_mapping_for_op( diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index b0249356eddb1..24e587846f1ba 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -515,6 +515,7 @@ def _insert_recv_op(block, idx, tensor, src): def _insert_concat_op(block, idx, tensors, axis): """Insert concat op into block 
at the given block.""" inputs = {'X': tensors} + print("concat_op", inputs, flush=True) attrs = {} attrs['axis'] = axis helper = LayerHelper('concat', **locals()) diff --git a/python/paddle/distributed/auto_parallel/rules.py b/python/paddle/distributed/auto_parallel/rules.py new file mode 100644 index 0000000000000..b48c90ec4ddd3 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/rules.py @@ -0,0 +1,94 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + + +def compute_compatible_process_mesh(process_mesh_list): + """Compute the compatible process mesh given a list of process meshes.""" + if not process_mesh_list: + return None + + def _compute_compatible_process_mesh_two(pm1, pm2): + if pm1 is None: + return True, pm2 + if pm2 is None: + return True, pm1 + if pm1 == pm2: + return True, pm1 + if pm1.processes == pm2.processes: + if len(pm1.topology) >= len(pm2.topology): + return True, pm1 + else: + return True, pm2 + process_set1 = set(pm1.processes) + process_set2 = set(pm2.processes) + if process_set1.issubset(process_set2): + return True, pm2 + if process_set2.issubset(process_set1): + return True, pm1 + return False, None + + compatible_result = None + for process_mesh in process_mesh_list: + compatible, compatible_result = _compute_compatible_process_mesh_two( + compatible_result, process_mesh) + if not compatible: + return None + return copy.deepcopy(compatible_result) + + +def compute_compatible_dim_mapping(dim_mapping_list): + """Compute the compatible dim mapping given a list of dim mapping.""" + if not dim_mapping_list: + return None + + def _compute_compatible_dim_mapping_two(dm1, dm2): + if dm1 == -1: + return True, dm2 + if dm2 == -1: + return True, dm1 + if dm1 == dm2: + return True, dm1 + return False, None + + compatible_result = -1 + for mapping in dim_mapping_list: + compatible, compatible_result = _compute_compatible_dim_mapping_two( + compatible_result, mapping) + if not compatible: + return None + return compatible_result + + +def compute_compatible_dims_mapping(dims_mapping_list): + """Compute the compatible dims mapping given a list of dims mapping. + Each of dims mapping is also a list. 
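For illustration only (this snippet is not part of the patch): the merging rule implemented by these new helpers treats -1 (replicated) as a wildcard and rejects a dimension that is sharded along two different mesh axes. Assuming the new module is importable as paddle.distributed.auto_parallel.rules, the expected behaviour would be:

    from paddle.distributed.auto_parallel.rules import (
        compute_compatible_dim_mapping, compute_compatible_dims_mapping)

    # -1 acts as a wildcard, so a concrete mesh axis wins.
    assert compute_compatible_dim_mapping([-1, 0, -1]) == 0
    # Two different mesh axes for the same tensor dimension cannot be merged.
    assert compute_compatible_dim_mapping([0, 1]) is None
    # Dims mappings are merged element-wise across the whole list.
    assert compute_compatible_dims_mapping([[0, -1, -1], [-1, -1, 1]]) == [0, -1, 1]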
+ """ + if not dims_mapping_list: + return None + length = len(dims_mapping_list[0]) + for dims_mapping in dims_mapping_list: + if dims_mapping is None: + return None + if len(dims_mapping) != length: + return None + compatible_result = [] + for dim_mappings in zip(*dims_mapping_list): + compatible_dim_mapping = compute_compatible_dim_mapping( + list(dim_mappings)) + if compatible_dim_mapping is None: + return None + compatible_result.append(compatible_dim_mapping) + return compatible_result diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py index 8593e44b3d820..7d94139e9a881 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py @@ -105,7 +105,7 @@ def test_api(self): self.assertEqual(dist_op.dist_attr.process_mesh, ProcessMesh(process_mesh2)) self.assertEqual(dist_op.dist_attr.impl_type, "default") - self.assertEqual(dist_op.dist_attr.impl_idx, -2) + self.assertEqual(dist_op.dist_attr.impl_idx, 0) self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh")) data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name) @@ -138,7 +138,7 @@ def test_api(self): dist_op = dist_context.get_dist_op_for_program(last_op) self.assertEqual(dist_op.dist_attr.process_mesh, None) self.assertEqual(dist_op.dist_attr.impl_type, "default") - self.assertEqual(dist_op.dist_attr.impl_idx, -2) + self.assertEqual(dist_op.dist_attr.impl_idx, 0) self.assertFalse(dist_op.dist_attr.is_annotated("process_mesh")) data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py index 05d71aca5db2c..7047f162aad46 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py @@ -76,15 +76,13 @@ def forward(self, input): auto.shard_tensor( self.linear0.weight, dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] + "process_mesh": [0, 1, 2, 3], + "dims_mapping": [-1, 0] }) auto.shard_tensor( self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + dist_attr={"process_mesh": [1, 5], + "dims_mapping": [0, -1]}) elif _global_parallel_strategy == "pp": auto.shard_tensor( self.linear0.weight, @@ -102,7 +100,15 @@ def forward(self, input): out = self.norm(input) out = self.linear0(out) out = F.gelu(out, approximate=True) + # if _global_parallel_strategy == "dp_mp": + # auto.shard_tensor( + # out, + # dist_attr={ + # "process_mesh": [1, 5], + # "dims_mapping": [-1, 1] + # }) out = self.linear1(out) + out = self.linear0(out) out = self.dropout(out) return out @@ -127,12 +133,13 @@ def mlp_pretrain_forward(train_program, start_program): "dims_mapping": [0, -1, -1] }) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor( - input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1, -1] - }) + pass + # auto.shard_tensor( + # input, + # dist_attr={ + # "process_mesh": [2, 4], + # "dims_mapping": [0, -1, -1] + # }) mlp = MLPLayer( hidden_size=hidden_size, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py index 3a28595c833e0..3575912a8b437 100644 --- 
a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py @@ -436,6 +436,13 @@ def forward(self, input): out = self.linear2(out) out = F.gelu(out, approximate=True) out = self.linear3(out) + if _global_parallel_strategy == "dp_mp_pp": + auto.shard_tensor( + out, + dist_attr={ + "process_mesh": _global_process_mesh[1], + "dims_mapping": [0, -1] + }) return out @@ -487,6 +494,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): no_grad_set=None, callbacks=None) + print("1****************", rank_id, flush=True) + print_program_with_dist_attr(complete_train_program, dist_context) + partitioner = Partitioner(dist_context, rank_id) dist_train_program, dist_startup_prog, dist_params_grads = partitioner.partition( complete_train_program, startup_program, params_grads) @@ -494,7 +504,11 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): partitioned_optimize_ops = parallelizer._apply_optimize( dist_train_program, dist_startup_prog, dist_params_grads) + print("2****************", rank_id, flush=True) reshard(dist_train_program, dist_startup_prog, rank_id, dist_context) + print("3****************", rank_id, flush=True) + print_program_with_dist_attr(dist_train_program, dist_context) + print("4****************", rank_id, flush=True) return dist_train_program, dist_startup_prog @@ -537,8 +551,8 @@ def test_mapper_dp_mp_pp(self): dist_context = DistributedContext() dist_train_program, dist_startup_prog = get_dist_prog( train_program, startup_program, dist_context, rank_id) - # if rank_id == 0: - # print_program_with_dist_attr(dist_train_program, dist_context) + if rank_id == 6: + print_program_with_dist_attr(dist_train_program, dist_context) dist_programs[rank_id] = [dist_train_program, None] rank_mapping = mapping(dist_programs, cluster) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py index ed64fa0630fa1..78ad64b1dd852 100755 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py @@ -154,7 +154,7 @@ def test_update(self): ops = train_program.global_block().ops vars = train_program.global_block().vars from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container - from paddle.distributed.auto_parallel.completion import is_elementwise_like_op + from paddle.distributed.auto_parallel.operators.common import is_elementwise_op from paddle.distributed.auto_parallel.dist_op import DistributedOperator for op in ops: @@ -163,7 +163,7 @@ def test_update(self): if dist_op_impl_container is None: op_dist_attr = dist_context.get_op_dist_attr_for_program(op) dist_op = DistributedOperator(op, op_dist_attr) - if is_elementwise_like_op(op.type): + if is_elementwise_op(op.type): changed = update_op_dims_mapping_by_elementwise_like_dist_impl( dist_op) self.assertFalse(changed) diff --git a/python/paddle/fluid/tests/unittests/test_auto_search_dist_matmul_op.py b/python/paddle/fluid/tests/unittests/test_auto_search_dist_matmul_op.py index c9cbcd1ea8efd..8c5913c66a70d 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_search_dist_matmul_op.py +++ b/python/paddle/fluid/tests/unittests/test_auto_search_dist_matmul_op.py @@ -96,7 +96,7 @@ def mlp_forward(train_program, start_program): return loss, train_program, start_program -class 
Testcompatible(unittest.TestCase): +class TestCompatible(unittest.TestCase): def test_matmulv2_matmul_2_compatible(self): valid_op_dist_attr_list = [] program = paddle.static.Program() @@ -123,7 +123,7 @@ def test_matmulv2_matmul_2_compatible(self): if op.type == 'matmul_v2' or op.type == 'matmul': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() X = op.input_arg_names[0] Y = op.input_arg_names[1] @@ -174,7 +174,7 @@ def test_matmulv2_matmul_2_compatible(self): op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1]) - self.assertFalse(impls[2].is_auto_compatible( + self.assertTrue(impls[2].is_auto_compatible( DistributedOperator(op, op_dist_attr))) op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1]) self.assertFalse(impls[2].is_auto_compatible( @@ -220,7 +220,7 @@ def test_matmulv2_matmul_1_compatible(self): if op.type == 'matmul_v2' or op.type == 'matmul': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() X = op.input_arg_names[0] Y = op.input_arg_names[1] @@ -261,7 +261,7 @@ def test_matmulv2_matmul_1_compatible(self): op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, 1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1]) - self.assertFalse(impls[1].is_auto_compatible( + self.assertTrue(impls[1].is_auto_compatible( DistributedOperator(op, op_dist_attr))) op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1]) self.assertFalse(impls[1].is_auto_compatible( @@ -307,7 +307,7 @@ def test_matmulv2_matmul_0_compatible(self): if op.type == 'matmul_v2' or op.type == 'matmul': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() X = op.input_arg_names[0] Y = op.input_arg_names[1] @@ -362,7 +362,7 @@ def test_matmulv2_matmul_0_compatible(self): op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1]) - self.assertFalse(impls[0].is_auto_compatible( + self.assertTrue(impls[0].is_auto_compatible( DistributedOperator(op, op_dist_attr))) op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1]) self.assertFalse(impls[0].is_auto_compatible( diff --git a/python/paddle/fluid/tests/unittests/test_auto_search_dist_op.py b/python/paddle/fluid/tests/unittests/test_auto_search_dist_op.py index 8f53a0c765d4c..4cb58eac7cc41 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_search_dist_op.py +++ b/python/paddle/fluid/tests/unittests/test_auto_search_dist_op.py @@ -96,24 +96,7 @@ def mlp_forward(train_program, start_program): return loss, train_program, start_program -class Testcompatible(unittest.TestCase): - def test_raise_compatible(self): - valid_op_dist_attr_list = [] - program = paddle.static.Program() - startup_program = paddle.static.Program() - loss, program, start_program = mlp_forward(program, startup_program) - ops = program.global_block().ops - for idx, op in enumerate(ops): - if op.type == 'transpose2': - op_dist_attr = 
OperatorDistributedAttribute() - dist_op = DistributedOperator(op, op_dist_attr) - impls = DistributedOperatorImpl() - try: - impls.is_auto_compatible(dist_op) - except NotImplementedError: - e = False - self.assertTrue(e == False) - +class TestCompatible(unittest.TestCase): def test_reshape_remove_compatible(self): valid_op_dist_attr_list = [] program = paddle.static.Program() @@ -124,7 +107,7 @@ def test_reshape_remove_compatible(self): if op.type == 'reshape2': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1, -1, -1]) @@ -172,64 +155,6 @@ def test_reshape_remove_compatible(self): self.assertFalse(impls[1].is_auto_compatible( DistributedOperator(op, op_dist_attr))) - def test_reshape_remove_two_compatible(self): - valid_op_dist_attr_list = [] - program = paddle.static.Program() - startup_program = paddle.static.Program() - loss, program, start_program = mlp_forward(program, startup_program) - ops = program.global_block().ops - for idx, op in enumerate(ops): - if op.type == 'reshape2': - dist_op_impl_container = get_distributed_operator_impl_container( - op.type) - impls = dist_op_impl_container.get_impls() - op_dist_attr = OperatorDistributedAttribute() - op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], - [-1, -1, -1]) - op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], - [-1]) - op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], - [-1, -1, -1, -1]) - dist_op = DistributedOperator(op, op_dist_attr) - self.assertTrue(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], - [-1, 1, 0]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], - [0, 1, 1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - - op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], - [1, -1, -1, -1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], - [-1, 1, 1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - - op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], - [-1, -1, -1, 1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - - op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], - [-1, 1, -1, -1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], - [-1, -1, 1, -1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - - op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], - [1, -1, -1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - def test_reshape_add_compatible(self): valid_op_dist_attr_list = [] program = paddle.static.Program() @@ -240,7 +165,7 @@ def test_reshape_add_compatible(self): if op.type == 'reshape2': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() 
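The test updates in this file, like those in test_auto_search_dist_matmul_op.py above, all follow the same pattern: fetch the implementation container, read its new impls property (replacing the removed get_impls() call), and probe each implementation with is_auto_compatible. A condensed sketch of that pattern, where the helper name and the concrete dims mappings are chosen purely for illustration:

    from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
    from paddle.distributed.auto_parallel.dist_op import DistributedOperator
    from paddle.distributed.auto_parallel.operators.common import (
        get_distributed_operator_impl_container)

    def first_compatible_impl_idx(op, in_mapping, out_mapping):
        """Return the index of the first implementation accepting these mappings."""
        container = get_distributed_operator_impl_container(op.type)
        op_dist_attr = OperatorDistributedAttribute()
        op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], in_mapping)
        op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], out_mapping)
        dist_op = DistributedOperator(op, op_dist_attr)
        for idx, impl in enumerate(container.impls):  # property access, not get_impls()
            if impl.is_auto_compatible(dist_op):
                return idx
        return None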
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1]) op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], @@ -298,7 +223,7 @@ def test_transpose_compatible(self): if op.type == 'transpose2': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1, -1]) @@ -349,7 +274,7 @@ def test_softmax_compatible(self): if op.type == 'softmax': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1, -1]) @@ -379,7 +304,7 @@ def test_embedding_compatible(self): if op.type == 'c_embedding' or op.type == 'lookup_table_v2': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1, -1]) From 41a478dfaf5894b8ec7509886299ede74657c57b Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Tue, 18 Jan 2022 06:25:03 +0000 Subject: [PATCH 04/11] Remove unnecessary modification --- .../paddle/distributed/auto_parallel/utils.py | 1415 ----------------- .../test_auto_parallel_completion.py | 31 +- .../unittests/test_auto_parallel_mapper.py | 18 +- .../unittests/test_auto_parallel_searcher.py | 4 +- 4 files changed, 16 insertions(+), 1452 deletions(-) delete mode 100644 python/paddle/distributed/auto_parallel/utils.py diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py deleted file mode 100644 index 1867731974f11..0000000000000 --- a/python/paddle/distributed/auto_parallel/utils.py +++ /dev/null @@ -1,1415 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License - -import os -import copy -import paddle -import threading -import numpy as np -import warnings -import logging -from functools import reduce - -import paddle.fluid.core as core -from paddle.framework.io import _to_LodTensor -from paddle.distributed.fleet.meta_optimizers.common import OpRole -from paddle.fluid.io import is_parameter, is_belong_to_optimizer -from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute - - -def is_valid_list_index(list, index): - if index >= -len(list) and index < len(list): - return True - else: - return False - - -def is_dim_shard(mapping): - if mapping != -1: - return True - else: - return False - - -def is_dim_replicate(mapping): - if mapping == -1: - return True - else: - return False - - -def compute_compatible_dim_mapping(dim_mappings): - if not dim_mappings: - return None - compatible_mapping = dim_mappings[0] - for mapping in dim_mappings: - if compatible_mapping == -1: - compatible_mapping = mapping - elif mapping == -1: - continue - elif compatible_mapping == mapping: - continue - else: - return None - return compatible_mapping - - -def compute_compatible_dims_mapping(dims_mapping_list): - if not dims_mapping_list: - return None - length = len(dims_mapping_list[0]) - for dims_mapping in dims_mapping_list: - assert dims_mapping is not None, \ - "Dims mapping must not be None for compatible computation" - assert len(dims_mapping) == length, \ - "The length of dims_mapping in list must be same for compatible computation." - compatible_result = [] - for dim_mappings in zip(*dims_mapping_list): - compatible_dim_mapping = compute_compatible_dim_mapping( - list(dim_mappings)) - if compatible_dim_mapping is None: - return None - compatible_result.append(compatible_dim_mapping) - return compatible_result - - -def compute_compatible_process_mesh(process_mesh_list): - compatible_process_mesh = None - if not process_mesh_list: - return compatible_process_mesh - for process_mesh in process_mesh_list: - if process_mesh is not None: - if compatible_process_mesh is None or compatible_process_mesh == process_mesh: - compatible_process_mesh = process_mesh - else: - return None - return compatible_process_mesh - - -def compute_compatible_and_update_dim_mapping(dims_mapping_list, index_list): - assert len(dims_mapping_list) == len(index_list) - changed = False - dim_mappings = [] - for i in range(len(dims_mapping_list)): - assert is_valid_list_index(dims_mapping_list[i], index_list[i]) - dim_mappings.append(dims_mapping_list[i][index_list[i]]) - compatible_dim_mapping = compute_compatible_dim_mapping(dim_mappings) - if compatible_dim_mapping is None: - return False - for i in range(len(dims_mapping_list)): - if compatible_dim_mapping != dims_mapping_list[i][index_list[i]]: - dims_mapping_list[i][index_list[i]] = compatible_dim_mapping - changed = True - return changed - - -def append_distributed_attr_suffix(name): - """ - Append auto parallel suffix for distributed attribute name. - """ - return name + core.kAutoParallelSuffix() - - -def remove_distributed_attr_suffix(name): - """ - Remove auto parallel suffix from distributed attribute name. 
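compute_compatible_and_update_dim_mapping, defined a few lines above and imported by the dist_softmax and dist_transpose implementations earlier in this patch, merges the dims-mapping entries found at the given indices (again with -1 as a wildcard) and writes the merged value back in place. A small worked example with made-up mappings (illustrative only, not part of the patch):

    # Align the contraction dimensions of X and Y for a matmul: X's last dim
    # must end up on the same mesh axis as Y's second-to-last dim.
    x_dims_mapping = [-1, 0]   # X's last dim sharded on mesh axis 0
    y_dims_mapping = [-1, -1]  # Y fully replicated
    changed = compute_compatible_and_update_dim_mapping(
        [x_dims_mapping, y_dims_mapping], [-1, -2])
    # Afterwards y_dims_mapping == [0, -1] and changed is True.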
- """ - return name.strip(core.kAutoParallelSuffix()) - - -def check_distributed_attr_for_program(program, dist_context=None): - from .dist_context import get_default_distributed_context - if dist_context is None: - dist_context = get_default_distributed_context() - assert dist_context.is_initialized_for_program(), \ - "Distributed attributes must be initialized before check." - for block in program.blocks: - for tensor in block.vars.values(): - dist_tensor = dist_context.get_dist_tensor_for_graph(tensor) - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( - tensor) - if (tensor_dist_attr is not None) and (not dist_tensor.is_valid()): - return False - for op in block.ops: - dist_op = dist_context.get_dist_op_for_graph(tensor) - op_dist_attr = dist_context.get_op_dist_attr_for_program(op) - if (op_dist_attr is not None) and (not dist_op.is_valid()): - return False - return True - - -def print_program_with_dist_attr(program, dist_context=None): - """ - This function reuses the original program output ability with a distributed context. - Using lock can avoid multiple threads change the default distributed context simultaneously. - """ - lock = threading.Lock() - lock.acquire() - from .dist_context import get_default_distributed_context - from .dist_context import set_default_distributed_context - if dist_context is None: - dist_context = get_default_distributed_context() - print(program) - else: - original_default_context = get_default_distributed_context() - set_default_distributed_context(dist_context) - print(program) - set_default_distributed_context(original_default_context) - lock.release() - - -def _get_comm_group(processes, shape, axis, rank): - """ - Given a rank and the processes mesh the rank belongs to, - compute the communication peers of the rank based on the give axis in the mesh. - - Example: 16 processes managed in a 4-Dimensinal mesh with shape of [2, 2, 2, 2]. - the rank communication peers of rank 0 (included) are following: - in axis 0: [0, 1] - in axis 1: [0, 2] - in axis 2: [0, 4] - in axis 3: [0, 8] - """ - - # NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous - # tricks to support processes mesh when it is not start with 0 or continuous - assert rank in processes, "rank [{}] is NOT in processes group {}".format( - rank, processes) - rank_relatvie = processes.index(rank) - coordinate = _linear_idx2coordinate(shape, rank_relatvie) - coordinates_in_group = [coordinate[:] for i in range(shape[axis])] - - # select comm group - for i in range(shape[axis]): - coordinates_in_group[i][axis] = i - - ranks_in_group_relative = [ - _coordinate2linear_idx(shape, coordinate) - for coordinate in coordinates_in_group - ] - ranks_in_group = [processes[idx] for idx in ranks_in_group_relative] - - return sorted(ranks_in_group) - - -def _get_idx_in_axis(processes, shape, axis, rank): - """ - Given a rank and the processes mesh the rank belongs to, - compute the index of the rank in given axis. - - Example: 27 processes managed in a 3-Dimensinal mesh with shape of [3, 3, 3]. 
- the index of rank 22 are: - in axis 0: 1 - in axis 1: 1 - in axis 2: 2 - """ - - # NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous - # tricks to support processes mesh when it is not start with 0 or continuous - rank_relatvie = processes.index(rank) - coordinate = _linear_idx2coordinate(shape, rank_relatvie) - return coordinate[axis] - - -def _coordinate2linear_idx(mesh_shape, coordinate): - """ - convert a coordinate in multidimensional mesh space into a scala idx in linear space. - - it use Row-major order for dimension conversion. - so it has: [most_significant_dim, ..., least_significant_dim] - assume: - - the size of i-th dimension to be: S[i] - the index of j-th dimension is: I[j] - - linear_idx of a n dimensional coordinate is: - - I[n-1] * (S[n-2] * S[n-3] * S[n-4] * .... S[0]) + - I[n-2] * ( S[n-3] * S[n-4] * .... S[0]) + - I[n-3] * ( S[n-4] * .... S[0]) + - ... - I[1] * ( S[0]) + - I[0] - - """ - # NOTE the following function work based on a strong an assumption - # that the processes in mesh are - # 1. starts from 0 - # 2. continuous - # it will be wrong if ths above condition doesnot meet, - # e.g. process_mesh = { process_groups = [7, 8, 9,10, 12, 13, 14, 15], mesh = [2, 4]} - # if you want a more general mapping, you should use cartesian product - - assert len(mesh_shape) == len( - coordinate - ), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format( - mesh_shape, coordinate) - for i in range(len(mesh_shape)): - assert coordinate[ - i] >= 0, "index in dimension [{}] is least than zero. coordinate: {}".format( - i, coordinate) - assert coordinate[i] < mesh_shape[ - i], "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format( - i, mesh_shape, coordinate) - - base = mesh_shape[-1] - linear_idx = coordinate[-1] - - # row major order - for i in range(len(mesh_shape) - 2, -1, -1): - linear_idx += base * coordinate[i] - base *= mesh_shape[i] - - return linear_idx - - -def _linear_idx2coordinate(mesh_shape, linear_idx): - """ - mapping a linear scala into multidimensional mesh space, return it coordinate in that space. - - it is the inverse function of _coordinate2linear_idx. - assume: - - the size of i-th dimension to be: S[i] - the index of j-th dimension is: I[j] - - the coordinate given linear_idx is: - - I[0] = linear_idx % S[0] - I[0] = (linear_idx / S[0]) % S[1] - I[0] = (linear_idx / (S[0] * S[1])) % S[2] - .... - - """ - - assert linear_idx >= 0, "linear index [{}] is least than zero".format( - linear_idx) - assert linear_idx < np.prod( - mesh_shape - ), "linear index beyond the extent of mesh shape. shape: {}, linear index: {}".format( - mesh_shape, linear_idx) - - base = 1 - coordinate = [-1] * len(mesh_shape) - - for i in reversed(range(len(mesh_shape))): - offset = linear_idx / base - coordinate[i] = int(offset % mesh_shape[i]) - base *= mesh_shape[i] - - # row major order - return coordinate - - -def _get_corresponding_rank(dist_context, target_mesh, rank): - - # TODO(JZ-LIANG) a hack method to support varying mesh in Pipeline parallelism case. - # we assume that all mesh are evenly divide from a parent mesh and should have same size. - # to revise this in future. 
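The two index conversions above are inverses of each other under the row-major convention (axis 0 most significant); a small self-check on an arbitrary [3, 3, 3] mesh:

    from paddle.distributed.auto_parallel.utils import (
        _coordinate2linear_idx, _linear_idx2coordinate)

    # 2*(3*3) + 1*3 + 0 == 21
    assert _coordinate2linear_idx([3, 3, 3], [2, 1, 0]) == 21
    assert _linear_idx2coordinate([3, 3, 3], 21) == [2, 1, 0]
    # round trip over every valid linear index
    assert all(
        _coordinate2linear_idx([3, 3, 3], _linear_idx2coordinate([3, 3, 3], i)) == i
        for i in range(27))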
- - coordinate = None - for mesh in dist_context.process_meshes: - if rank in mesh.processes and mesh.topology == target_mesh.topology: - coordinate = _linear_idx2coordinate(mesh.topology, - mesh.processes.index(rank)) - break - - assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format( - rank) - return target_mesh.processes[_coordinate2linear_idx(mesh.topology, - coordinate)] - - -def _get_unshard_dist_shape(var, dist_attr): - var_shape = var.shape - mapping = dist_attr.dims_mapping - mesh = dist_attr.process_mesh.topology - assert len(var_shape) == len( - mapping - ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( - var_shape, mapping) - new_shape = [] - for idx in range(len(var_shape)): - if var_shape[idx] == -1 or mapping[idx] == -1: - new_shape.append(var_shape[idx]) - else: - new_shape.append(var_shape[idx] * mesh[mapping[idx]]) - - return new_shape - - -def make_data_unshard(dist_main_prog, dist_startup_prog, dist_context=None): - from .dist_context import get_default_distributed_context - if dist_context is None: - dist_context = get_default_distributed_context() - - for var in dist_main_prog.list_vars(): - if var.is_data: - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( - var) - inverse_shape = _get_unshard_dist_shape(var, tensor_dist_attr) - var.desc.set_shape(inverse_shape) - dim_mapping = tensor_dist_attr.dims_mapping - dim_mapping = [-1] * len(dim_mapping) - tensor_dist_attr.dims_mapping = dim_mapping - dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr) - - -def _update_addition_info(addition_info): - """ Update default addition_info with inputs """ - add_info = {"epoch": 0, "batch": 0, "batch_size": 0} - if not addition_info: - return add_info - elif not isinstance(addition_info, dict): - raise TypeError("The type of 'addition_info' should be 'dict', " - "but got '{}'.".format(str(type(addition_info)))) - else: - for item, value in addition_info.items(): - if item not in ["epoch", "batch", "batch_size"]: - raise ValueError( - "The key of 'addition_info' should be one of the " - "['epoch', 'batch', 'batch_size'], but got '{}'." - .format(str(item))) - if not isinstance(value, int): - raise ValueError( - "The value of 'addition_info' should be 'int', " - "but got '{}'.".format(str(type(value)))) - add_info[item] = value - return add_info - - -def _check_valid_path(file_path): - """ Validity check of input file path """ - if not file_path: - return file_path - elif isinstance(file_path, list): - for file in file_path: - if not isinstance(file, str): - raise TypeError("The type of file path should be 'str', " - "but got '{}'.".format(str(type(file)))) - if not os.path.exists(file): - raise ValueError("The file path '{}' does not exist." 
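_get_unshard_dist_shape above just multiplies every sharded dimension back up by the extent of the mesh axis it is mapped to. The sketch below hand-computes the same rule rather than calling the helper (which expects framework var and dist-attr objects); the shapes and mesh are invented:

    # local shard of shape [8, 1024], dim 0 sharded on mesh axis 0 of a [2, 4] topology
    local_shape, dims_mapping, topology = [8, 1024], [0, -1], [2, 4]
    global_shape = [
        s if (s == -1 or m == -1) else s * topology[m]
        for s, m in zip(local_shape, dims_mapping)
    ]
    assert global_shape == [16, 1024]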
- .format(file)) - return file_path - else: - raise TypeError("The type of file path should be 'list', " - "but got '{}'.".format(str(type(file_path)))) - - -def _check_param_dict(param_dict): - if not param_dict: - raise ValueError("'param_dict' cannot be None.") - elif not isinstance(param_dict, dict): - raise TypeError("The type of 'param_dict' should be 'dict', " - "but got '{}'.".format(str(type(param_dict)))) - else: - for name, value in param_dict.items(): - if not isinstance(name, str): - raise TypeError( - "The type of key of 'param_dict' should be 'str', " - "but got '{}'.".format(str(type(name)))) - if not isinstance(value, paddle.fluid.LoDTensor): - raise TypeError( - "The type of value of 'param_dict' should be 'LoDTensor', " - "but got '{}'.".format(str(type(value)))) - return param_dict - - -def _check_dist_attr(dist_attr): - if not dist_attr: - return dist_attr - elif not isinstance(dist_attr, dict): - raise TypeError("The type of 'dist_attr' should be 'dict', " - "but got '{}'.".format(str(type(dist_attr)))) - else: - for name, value in dist_attr.items(): - if not isinstance(name, str): - raise TypeError( - "The type of param name of 'dist_attr' should be 'str', " - "but got '{}'.".format(str(type(name)))) - if not isinstance(value, dict): - raise TypeError( - "The type of distributed attribute should be 'dict', " - "but got '{}'".format(str(type(value)))) - attr = ['process_shape', 'process_group', 'dims_mapping'] - if list(value.keys()) != attr: - raise ValueError( - "The key of distributed attribute should be " - "'['process_shape', 'process_group', 'dims_mapping']', " - "but got {}.".format(str(value.keys()))) - return dist_attr - - -def save_distributed_checkpoint(program, - checkpoint_path, - dist_attr_path, - addition_info=None, - is_integrated=False, - dist_context=None): - """ - Save model parameter state, optimzer state, distributed attribute and - additional information of each rank. - - Args: - program(Program): The program to be saved. - checkpoint_path(str): The path of the checkpoint file to be saved. - dist_attr_path(str): The path of distributed attribute file to be saved. - addition_info(dict, optional): Additional information, key should be selected in ['epoch', 'batch', 'batch_size']. - Default values are 0, when 'addition_info' is None. Default: None. - is_integrated(bool, optional): Whether to integrate param before save. Default: False. - dist_context(DistributedContext ,optional): collect related distributed information for program - - Returns: - None - - Examples: - .. code-block:: python - - path = os.path.join("./output", "step_%d" % step) - os.makedirs(path, exist_ok=True) - add_info = {'batch': step, "batch_size": global_batch_size} - save_distributed_checkpoint(program, path, path, add_info) - """ - from .dist_context import get_default_distributed_context - - assert isinstance(program, paddle.fluid.framework.Program) - assert isinstance(is_integrated, bool) - if dist_context is None: - dist_context = get_default_distributed_context() - addition_info = _update_addition_info(addition_info) - - if not is_integrated: - _save_distributed_state_dict(program, addition_info, checkpoint_path) - _save_distributed_attribute(program, dist_attr_path, dist_context) - else: - # TODO: integrate param before save - raise NotImplementedError( - "Integrating parameter has not been implemented.") - - -def load_distributed_checkpoint(checkpoint_path, dist_attr_path): - """ - Load parameter, optimizer, distributed attribute and addition_info. 
- - Args: - checkpoint_path(list[str]): model parameter file path, must be in order of rank id. - dist_attr_path(list[str]): distributed attribute file path, must be in order of rank id. - - Returns: - param_dict(dict): parameters' value of all ranks. - dist_attr(dict): parameters' distributed attribute. - addition_info(dict): additional information user saved in last training. - - Notes: - The return, 'addition_info', is belonging to the first file of checkpoint_path by default. - - Examples: - .. code-block:: python - - ckpt_path = ['./model_state_rank0.pdmodel', - './model_state_rank1.pdmodel'] - dist_attr_path = ['./dist_attr_rank0.pdattr', - './dist_attr_rank1.pdattr'] - param_dict, dist_attr, add_info = load_distributed_checkpoint(ckpt_path, dist_attr_path) - """ - assert _check_valid_path(checkpoint_path), \ - "'checkpoint_path' cannot be None." - assert _check_valid_path(dist_attr_path), \ - "'dist_attr_path' cannot be None." - - state_dict_info = _load_distributed_state_dict(checkpoint_path) - dist_attr = _load_distributed_attribute(dist_attr_path) - param_dict = state_dict_info["model"] - addition_info = state_dict_info["addition_info"] - return param_dict, dist_attr, addition_info - - -def load_checkpoint_into_program(checkpoint_path, - dist_attr_path, - program, - dist_context=None): - """ - Load parameter, optimizer, distributed attribute and addition_info into model. - - Args: - checkpoint_path(list[str]): model parameter file path, must be in order of rank id. - dist_attr_path(list[str]): distributed attribute file path, must be in order of rank id. - program(Program): the program to be updated with checkpoint_path. - dist_context(DistributedContext ,optional): collect related distributed information for program - - Returns: - addition_info(dict): user saved in last train. - - Notes: - The return, 'addition_info', is belonging to the first file of checkpoint_path by default. - - Examples: - .. code-block:: python - - exe.run(startup_program) - ckpt_path = ['./model_state_rank0.pdmodel', - './model_state_rank1.pdmodel'] - dist_attr_path = ['./dist_attr_rank0.pdattr', - './dist_attr_rank1.pdattr'] - load_checkpoint_into_program(ckpt_path, dist_attr_path, main_program) - """ - from .dist_context import get_default_distributed_context - - assert isinstance(program, paddle.fluid.framework.Program) - assert _check_valid_path(checkpoint_path), \ - "'checkpoint_path' cannot be None." - assert _check_valid_path(dist_attr_path), \ - "'dist_attr_path' cannot be None." - if dist_context is None: - dist_context = get_default_distributed_context() - all_state_dict_info = _load_distributed_state_dict(checkpoint_path) - all_pre_dist_attr = _load_distributed_attribute(dist_attr_path) - all_cur_dist_attr = get_dist_attr(program, dist_context) - all_param_dict = all_state_dict_info["model"] - addition_info = all_state_dict_info["addition_info"] - sliced_param_dict = merge_and_slice_parameter( - all_param_dict, all_pre_dist_attr, all_cur_dist_attr) - load_parameter_into_program(sliced_param_dict, program) - - return addition_info - - -def load_parameter_into_program(param_dict, program): - """ - Load parameters into program. - - Args: - param_dict(dict): parameters' name and value. 
- program(Program): the program to be updated - """ - _check_param_dict(param_dict) - assert program and isinstance(program, paddle.fluid.framework.Program) - program.set_state_dict(param_dict) - - -def _save_distributed_attribute(program, dist_attr_path, dist_context): - """ Save distributed attribute of all parameters """ - # TODO: just save a complete distributed attribute file - rank_id = paddle.distributed.get_rank() - dist_attr_name = os.path.join(dist_attr_path, - "dist_attr_rank{}.pdattr".format(rank_id)) - dist_attr_dict = { - "model": get_dist_attr(program, dist_context), - "world_size": paddle.distributed.get_world_size() - } - paddle.save(dist_attr_dict, dist_attr_name) - logging.info("Already saved distributed attribute to '{}'.".format( - dist_attr_path)) - - -def _load_distributed_attribute(dist_attr_path): - """ Load parameters' distributed attribute from dist_attr_path """ - total_dist_attr = {} - for dist_attr_file in dist_attr_path: - dist_attr = paddle.load(dist_attr_file) - pre_world_size = dist_attr["world_size"] - assert pre_world_size == len(dist_attr_path), \ - "The number of 'dist_attr_path' must be equal to the last training world size." - for name, attr in dist_attr["model"].items(): - if name not in total_dist_attr: - total_dist_attr[name] = attr - - return total_dist_attr - - -def _save_distributed_state_dict(program, addition_info, checkpoint_path): - """ Save parameters' state_dict """ - rank = paddle.distributed.get_rank() - ckpt_file_name = os.path.join(checkpoint_path, - "model_state_rank{}.pdmodel".format(rank)) - state_dict = { - "model": program.state_dict(), - "world_size": paddle.distributed.get_world_size(), - "addition_info": addition_info - } - paddle.save(state_dict, ckpt_file_name) - logging.info("Already saved model to '{}'.".format(checkpoint_path)) - - -def _load_distributed_state_dict(checkpoint_path): - """ Load parameters' state_dict from checkpoint_path """ - all_state_dict = {} - for idx, ckpt_file in enumerate(checkpoint_path): - state_dict_info = paddle.load(ckpt_file, return_numpy=True) - pre_world_size = state_dict_info["world_size"] - assert pre_world_size == len(checkpoint_path), \ - "The number of 'checkpoint_path' must be equal to the last training world size." - if idx == 0: - addition_info = state_dict_info["addition_info"] - for name, value in state_dict_info["model"].items(): - if name in all_state_dict: - all_state_dict[name].append(np.array(value)) - else: - all_state_dict[name] = [np.array(value)] - - all_state_dict_info = { - "model": all_state_dict, - "addition_info": addition_info - } - return all_state_dict_info - - -def get_dist_attr(program, dist_context=None): - """ - Get distributed attribute of current rank. 
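To summarise the on-disk layout produced by the two private save helpers above: each rank writes one state file and one attribute file, both small dicts. A sketch of reading them back, assuming a checkpoint was already written under a hypothetical ./ckpt directory:

    import paddle

    rank = paddle.distributed.get_rank()
    state = paddle.load("./ckpt/model_state_rank{}.pdmodel".format(rank))
    # {"model": <this rank's program.state_dict()>, "world_size": N,
    #  "addition_info": {"epoch": ..., "batch": ..., "batch_size": ...}}
    assert set(state.keys()) == {"model", "world_size", "addition_info"}

    attrs = paddle.load("./ckpt/dist_attr_rank{}.pdattr".format(rank))
    # {"model": {param_name: {"process_shape", "process_group", "dims_mapping"}},
    #  "world_size": N}
    assert set(attrs.keys()) == {"model", "world_size"}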
- - Args: - program(Program): main program for training - """ - from .dist_context import get_default_distributed_context - - assert isinstance(program, paddle.fluid.framework.Program) - if dist_context is None: - dist_context = get_default_distributed_context() - dist_attr = {} - for var in program.list_vars(): - if is_parameter(var) or is_belong_to_optimizer(var): - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( - var) - process_mesh = tensor_dist_attr.process_mesh - dims_mapping = tensor_dist_attr.dims_mapping - dist_attr[var.name] = { - "process_shape": process_mesh.topology, - "process_group": process_mesh.processes, - "dims_mapping": dims_mapping - } - return dist_attr - - -def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr): - """ - Merge parameters with previous dist_attr and slice parameters with current dist_attr - - Arags: - dist_param_dict(dict): parameters' value of all ranks. - pre_dist_attr(dict): parameters' dist_attr of last training process. - cur_dist_attr(dict): parameters' dist_attr of current training process. - - Returns: - dist_param_dict(dict): parameters' value of current rank. - """ - assert _check_dist_attr(pre_dist_attr), "'pre_dist_attr' cannot be None." - assert _check_dist_attr(cur_dist_attr), "'pre_dist_attr' cannot be None." - assert isinstance(dist_param_dict, dict), \ - "The type of 'dist_param_dict' should be 'dict', but got {}.".format( - str(type(dist_param_dict))) - for name, value in dist_param_dict.items(): - if not isinstance(name, str): - raise TypeError("The key of 'dist_param_dict' is parameter's name, " - "and its type should be 'str', but got {}." - .format(str(type(name)))) - if not isinstance(value, list) or not all( - isinstance(v, np.ndarray) for v in value): - raise TypeError( - "The value of 'dist_param_dict' is parameter's value of all ranks, " - "and its type should be 'list(numpy.ndarray)'.") - - param_not_in_pre = [] - param_not_in_cur = [] - logging.info("Start to merge and slice parameters.") - for var_name in cur_dist_attr.keys(): - if var_name not in pre_dist_attr: - param_not_in_pre.append(var_name) - continue - - pre_attr = pre_dist_attr[var_name] - cur_attr = cur_dist_attr[var_name] - if pre_attr == cur_attr: - # skip merge and slice - rank_id = paddle.distributed.get_rank() - index = cur_attr["process_group"].index(rank_id) - param = dist_param_dict[var_name][index] - dist_param_dict[var_name] = _to_LodTensor(param) - continue - - pre_param = dist_param_dict[var_name] - pre_dims_mapping = pre_attr["dims_mapping"] - cur_dims_mapping = cur_attr["dims_mapping"] - if len(set(pre_dims_mapping)) > 1 or -1 not in pre_dims_mapping: - complete_param = _merge_parameter_with_dist_attr(pre_param, - pre_attr) - dist_param_dict[var_name] = complete_param - else: - complete_param = pre_param[0] - dist_param_dict[var_name] = _to_LodTensor(complete_param) - - if len(set(cur_dims_mapping)) > 1 or -1 not in cur_dims_mapping: - sliced_param = _slice_parameter_with_dist_attr(complete_param, - cur_attr) - dist_param_dict[var_name] = sliced_param - - for var_name in pre_dist_attr: - if var_name not in cur_dist_attr: - param_not_in_cur.append(var_name) - dist_param_dict.pop(var_name) - - if param_not_in_pre: - warnings.warn("Parameters '{}' are not found in last training process." - .format(str(param_not_in_pre))) - if param_not_in_cur: - warnings.warn( - "Parameters '{}' are not found in current training process." 
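The dict formats flowing between get_dist_attr and merge_and_slice_parameter are easy to mock up; everything below (parameter name, shapes, mesh) is invented for illustration only:

    import numpy as np

    # One parameter, column-sharded over a [2] mesh in the previous run and
    # fully replicated in the current run.
    pre_dist_attr = {"linear_0.w_0": {"process_shape": [2],
                                      "process_group": [0, 1],
                                      "dims_mapping": [-1, 0]}}
    cur_dist_attr = {"linear_0.w_0": {"process_shape": [2],
                                      "process_group": [0, 1],
                                      "dims_mapping": [-1, -1]}}
    # Per-rank numpy shards, in process_group order.
    dist_param_dict = {"linear_0.w_0": [np.zeros((4, 3)), np.ones((4, 3))]}
    # merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr)
    # would concatenate the two (4, 3) shards along dim 1 into one (4, 6) tensor
    # and, since the current mapping is fully replicated, skip the slicing step.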
- .format(str(param_not_in_cur))) - - return dist_param_dict - - -def _merge_parameter_with_dist_attr(param_list, dist_attr): - """ Merge parameter with distributed attribute """ - from .reshard import _compute_complete_shape, _compute_partition_index - - dims_mapping = dist_attr["dims_mapping"] - process_shape = dist_attr["process_shape"] - process_group = dist_attr["process_group"] - # get the complete shape of the parameter - complete_shape = _compute_complete_shape(param_list[0].shape, process_shape, - dims_mapping) - # merge the parameter with dist_attr - partition_param_list = [] - merged_partiton = [] - for process in process_group: - partition_index = _compute_partition_index( - process, complete_shape, dims_mapping, process_shape, process_group) - index = process_group.index(process) - if partition_index not in merged_partiton: - merged_partiton.append(partition_index) - _merge_parameter(partition_param_list, param_list[index], - partition_index, complete_shape) - - assert len(partition_param_list) == 1 or not partition_param_list, \ - "Fail to merge parameter" - complete_param = _to_LodTensor(partition_param_list[0][0]) - return complete_param - - -def _slice_parameter_with_dist_attr(param, dist_attr): - """ Slice parameter with distributed attribute """ - param = np.array(param) if isinstance(param, - paddle.fluid.LoDTensor) else param - dims_mapping = dist_attr["dims_mapping"] - process_shape = dist_attr["process_shape"] - process_group = dist_attr["process_group"] - # slice the parameter with dist_attr - partition_index_list = _get_split_indices(param.shape, dims_mapping, - process_shape, process_group) - sliced_param_list = _slice_parameter(param, partition_index_list, - len(partition_index_list)) - # get the current parameter's index in sliced_param_list - rank_id = paddle.distributed.get_rank() - sliced_param_index = _get_sliced_param_index( - rank_id, param.shape, dims_mapping, process_shape, process_group) - sliced_param = _to_LodTensor(sliced_param_list[sliced_param_index]) - return sliced_param - - -def _merge_parameter(partition_param_list, param, partition_index, - complete_shape): - """ - Merge partitial parameters to a complete one. - - Returns: - None - - Examples: - .. 
code-block:: python - - import numpy as np - partition_param_list = [(np.array([[[1.11, 1.12]]]), [[0,1],[0,1],[0,2]])] - param = np.array([[[1.13, 1.14]]]) - partition_index = [[0,1],[0,1],[2,4]] - - _merge_parameter(partition_param_list, param, partition_index) - # partition_param_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])] - """ - from .reshard import _compute_concat_info - - if len(partition_param_list) == 1: - is_complete_data = True - for idx, item in enumerate(partition_param_list[0][1]): - if item[0] != 0 or item[1] != complete_shape[idx]: - is_complete_data = False - break - if is_complete_data: - return - - if not partition_param_list: - partition_param_list.append((param, partition_index)) - else: - i = 0 - while i < len(partition_param_list): - concat_axis, first_order, new_partition = _compute_concat_info( - partition_param_list[i][1], partition_index) - if concat_axis != -1: - if first_order == 0: - new_param = np.concatenate( - (partition_param_list[i][0], param), axis=concat_axis) - else: - new_param = np.concatenate( - (param, partition_param_list[i][0]), axis=concat_axis) - - partition_param_list.pop(i) - _merge_parameter(partition_param_list, new_param, new_partition, - complete_shape) - break - i += 1 - - -def _slice_parameter(complete_param, partition_index_list, length): - """ - Slice a complete parameter. - - Returns: - sliced_param_list(list): sliced parameters with 'partition_index_list' - - Examples: - .. code-block:: python - - import numpy as np - complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) - rank = 2 - complete_shape = [1, 1, 6] - dims_mapping = [-1, -1, 0] - process_shape = [3] - process_group = [0, 1, 2] - - sliced_param_list = _slice_parameter(complete_param, [[], [], [2, 4]], 3) - # [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])] - """ - sliced_param_list = [] - axis = len(complete_param.shape) - length - sliced_param = np.split( - complete_param, partition_index_list[axis], axis=axis) - if length == 1: - return sliced_param - for param in sliced_param: - sliced_param_list.extend( - _slice_parameter(param, partition_index_list, length - 1)) - return sliced_param_list - - -def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape, - process_group): - """ - Get sliced_param's index of current rank in all sliced parameters list. - - Returns: - sliced_param_index(int): the index of sliced param in sliced_param_list - - Examples: - .. 
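One note on the _merge_parameter example a few lines above: the function also takes the complete shape as its fourth argument, which the docstring call omits. A call consistent with the signature, reusing the same illustrative values:

    import numpy as np
    from paddle.distributed.auto_parallel.utils import _merge_parameter

    partition_param_list = [(np.array([[[1.11, 1.12]]]), [[0, 1], [0, 1], [0, 2]])]
    _merge_parameter(partition_param_list, np.array([[[1.13, 1.14]]]),
                     [[0, 1], [0, 1], [2, 4]], complete_shape=[1, 1, 4])
    # partition_param_list is now
    # [(array([[[1.11, 1.12, 1.13, 1.14]]]), [[0, 1], [0, 1], [0, 4]])]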
code-block:: python - - import numpy as np - complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) - rank = 2 - complete_shape = [1, 1, 6] - dims_mapping = [-1, -1, 0] - process_shape = [3] - process_group = [0, 1, 2] - - slice_param = _slice_parameter(complete_param, [[], [], [2, 4]], 3) - # slice_param: - # [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])] - - index = _get_sliced_param_index(rank, complete_shape, dims_mapping - process_shape, process_group) - # index: 2 - """ - from .reshard import _compute_partition_index - - partition_index = _compute_partition_index( - rank, complete_shape, dims_mapping, process_shape, process_group) - sliced_param_index = 0 - for i, shape in enumerate(complete_shape): - if dims_mapping[i] == -1: - slice_shape = shape - else: - slice_shape = shape // process_shape[dims_mapping[i]] - if shape == 1: - index = 0 - else: - index = (partition_index[i][0] + 1) // slice_shape - sliced_param_index = sliced_param_index * (shape // slice_shape) + index - return sliced_param_index - - -def _get_split_indices(complete_shape, dims_mapping, process_shape, - process_group): - """ - Get split indices of every dimension. - - Returns: - split_indices_list(list): the split indices of every dimension of the parameter - - Examples: - .. code-block:: python - - import numpy as np - complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) - complete_shape = [1, 1, 6] - dims_mapping = [-1, -1, 0] - process_shape = [3] - process_group = [0, 1, 2] - - index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group) - # index: [[], [], [2, 4]] - """ - from .reshard import _compute_partition_index - - split_indices_list = [] - for process in process_group: - partition_index = _compute_partition_index( - process, complete_shape, dims_mapping, process_shape, process_group) - if split_indices_list: - for dim in range(len(partition_index)): - split_indices_list[dim].extend(partition_index[dim]) - else: - split_indices_list = partition_index - split_indices_list = list( - map(lambda x, y: list(set(x) - set([y]) - set([0])), split_indices_list, - complete_shape)) - split_indices_list = [sorted(x) for x in split_indices_list] - return split_indices_list - - -def set_grad_var_shape(program, dist_context): - from .operators.common import infer_shape - from paddle.distributed.fleet.meta_optimizers.common import OpRole - - block = program.global_block() - vars = block.vars - for op in block.ops: - - if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: - break - - if op.type in ["sum"]: - continue - if int(op.attr('op_role')) == int(OpRole.Backward): - op_dist_attr = dist_context.get_op_dist_attr_for_program(op) - assert op_dist_attr is not None - - for var_name in op.output_arg_names: - if "@GRAD" not in var_name: - continue - forward_var_name = var_name[:var_name.find("@GRAD")] - if op.type in [ - "c_allreduce_sum", "c_identity", "scale", "cast" - ]: - forward_var_name = op.input_arg_names[0] - elif op.type == "matmul_v2_grad": - forward_var_name = None - for output_name in op.output_names: - if var_name in op.output(output_name): - assert "@GRAD" in output_name - input_name = output_name[:output_name.find("@GRAD")] - assert len(op.input(input_name)) == 1 - forward_var_name = op.input(input_name)[0] - assert forward_var_name is not None - - need_set_shape_list = [ - "reshape2_grad", "softmax_with_cross_entropy_grad", - "transpose2_grad", "softmax_grad", "cross_entropy_grad2", - "dropout_grad" - ] - 
forward_list = [ - "reshape2", "softmax_with_cross_entropy", "transpose2", - "softmax", "cross_entropy2", "dropout" - ] - if op.type in need_set_shape_list: - for forward_op in block.ops: - assert int(forward_op.attr('op_role')) != int( - OpRole.Backward) - idx = need_set_shape_list.index(op.type) - forward_op_name = forward_list[idx] - if forward_op.type == forward_op_name and forward_var_name in forward_op.input_arg_names: - op_dist_attr = dist_context.get_op_dist_attr_for_program( - forward_op) - break - - forward_input_dist_attr = op_dist_attr.get_input_dist_attr( - forward_var_name) - - assert forward_input_dist_attr is not None, f"{forward_var_name}" - forward_var = vars[forward_var_name] - forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( - forward_var) - assert forward_var_dist_attr is not None - grad_var = vars[var_name] - ref_shape = infer_shape(block, forward_var, - forward_var_dist_attr, - forward_input_dist_attr) - - if list(grad_var.shape) != ref_shape: - grad_var.desc.set_shape(ref_shape) - - -OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() -OpRole = core.op_proto_and_checker_maker.OpRole - - -def is_forward_op(op): - ref_role1 = int(core.op_proto_and_checker_maker.OpRole.Forward) - ref_role2 = int(core.op_proto_and_checker_maker.OpRole.Loss) - op_role = int(op.attr('op_role')) - return OP_ROLE_KEY in op.attr_names and (op_role == ref_role1 or - op_role == ref_role2) - - -def is_backward_op(op): - return OP_ROLE_KEY in op.attr_names and \ - int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward) - - -def is_loss_op(op): - return OP_ROLE_KEY in op.attr_names and \ - int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss)) - - -def get_loss_op(block): - loss_ops = [] - for op in block.ops: - if is_loss_op(op): - assert len(op.desc.output_arg_names( - )) == 1, "loss op should only output loss var" - loss_ops.append(op) - - assert len(loss_ops) == 1, "num of loss op is not equal to one" - return loss_ops[0] - - -def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs): - tensor_dist_attr = TensorDistributedAttribute() - tensor_dist_attr.dims_mapping = dims_mapping - # TODO get global mesh group - tensor_dist_attr.process_mesh = process_mesh - dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr) - return tensor_dist_attr - - -def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(new_op, process_mesh, - ref_mapping, ctx): - assert process_mesh is not None - assert ref_mapping is not None - - new_op_dist_attr = OperatorDistributedAttribute() - - for input_varname in new_op.desc.input_arg_names(): - new_op_dist_attr.set_input_dims_mapping(input_varname, ref_mapping) - for output_varname in new_op.desc.output_arg_names(): - new_op_dist_attr.set_output_dims_mapping(output_varname, ref_mapping) - - new_op_dist_attr.process_mesh = process_mesh - ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr) - - -def update_op_dims_mapping_by_default_dist_impl(dist_op): - changed = False - op_dist_attr = dist_op.dist_attr - op_desc = dist_op.serial_op.desc - # The following statement will be replaced by a more elegent way - if op_desc.type() == "shape" or op_desc.type() == "slice": - return False - output_names = op_desc.output_names() - xshape_arg_names = [] - if "XShape" in output_names: - xshape_arg_names = op_desc.output("XShape") - batch_dim_mappings = [] - for arg_name in op_desc.input_arg_names(): - serial_tensor = 
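The three role predicates above are plain bit tests on the op_role attribute; the following only restates the relationship they rely on, using the framework's own enum (no new behaviour, just a sanity check):

    from paddle.fluid import core

    OpRole = core.op_proto_and_checker_maker.OpRole
    loss_role = int(OpRole.Forward) | int(OpRole.Loss)        # matched by is_loss_op
    loss_grad_role = int(OpRole.Backward) | int(OpRole.Loss)
    # the op producing the loss gradient still counts as a backward op,
    # because the Backward bit is set
    assert loss_grad_role & int(OpRole.Backward)
    assert not (loss_role & int(OpRole.Backward))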
dist_op.get_serial_input(arg_name) - if serial_tensor.is_parameter: - continue - dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) - if len(dims_mapping) > 1: - for idx, mapping in enumerate(dims_mapping[1:]): - assert mapping == -1, \ - "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ - .format(op_desc.type(), idx, mapping) - batch_dim_mappings.append(dims_mapping[0]) - for arg_name in op_desc.output_arg_names(): - serial_tensor = dist_op.get_serial_output(arg_name) - if serial_tensor.is_parameter: - continue - dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - if arg_name not in xshape_arg_names: - if len(dims_mapping) > 1: - for idx, mapping in enumerate(dims_mapping[1:]): - assert mapping == -1, \ - "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ - .format(op_desc.type(), idx, mapping) - batch_dim_mappings.append(dims_mapping[0]) - else: - assert dims_mapping[0] == -1, \ - "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\ - .format(op_desc.type(), mapping) - if len(dims_mapping) > 2: - for idx, mapping in enumerate(dims_mapping[2:]): - assert mapping == -1, \ - "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\ - .format(op_desc.type(), idx, mapping) - batch_dim_mappings.append(dims_mapping[1]) - - compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings) - assert compatible_dim_mapping is not None, "There is no compatible dim mapping." - for arg_name in op_desc.input_arg_names(): - serial_tensor = dist_op.get_serial_input(arg_name) - if serial_tensor.is_parameter: - continue - dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) - if compatible_dim_mapping != dims_mapping[0]: - dims_mapping[0] = compatible_dim_mapping - changed = True - for arg_name in op_desc.output_arg_names(): - serial_tensor = dist_op.get_serial_output(arg_name) - if serial_tensor.is_parameter: - continue - dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - if arg_name not in xshape_arg_names: - if compatible_dim_mapping != dims_mapping[0]: - dims_mapping[0] = compatible_dim_mapping - changed = True - else: - if compatible_dim_mapping != dims_mapping[1]: - dims_mapping[1] = compatible_dim_mapping - changed = True - - return changed - - -def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op): - changed = False - op_dist_attr = dist_op.dist_attr - op_desc = dist_op.serial_op.desc - input_arg_names = op_desc.input_arg_names() - input_dims_mapping_dict = {} - input_dims_mapping_lens = {} - max_dims_mapping_len = -1 - for arg_name in input_arg_names: - dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) - if max_dims_mapping_len < len(dims_mapping): - max_dims_mapping_len = len(dims_mapping) - input_dims_mapping_dict[arg_name] = dims_mapping - input_dims_mapping_lens[arg_name] = len(dims_mapping) - - dims_mapping_list = [] - for arg_name in input_arg_names: - if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: - new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)] - for i in range(input_dims_mapping_lens[arg_name]): - new_idx = (max_dims_mapping_len - - input_dims_mapping_lens[arg_name]) + i - new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i] - dims_mapping_list.append(new_dims_mapping) - else: - dims_mapping_list.append(input_dims_mapping_dict[arg_name]) - output_arg_names = 
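The elementwise-like update above right-aligns every shorter dims_mapping before asking for a compatible result; a stripped-down sketch of just that padding step, with hypothetical mappings and no DistributedOperator involved:

    from paddle.distributed.auto_parallel.utils import compute_compatible_dims_mapping

    def right_align(dims_mapping, max_len):
        # same rule as the loop above: missing leading dims are treated as replicated
        return [-1] * (max_len - len(dims_mapping)) + list(dims_mapping)

    x_mapping = right_align([-1, 0], 3)   # rank-2 input broadcast against rank-3 inputs
    y_mapping = [1, -1, 0]
    assert x_mapping == [-1, -1, 0]
    assert compute_compatible_dims_mapping([x_mapping, y_mapping]) == [1, -1, 0]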
op_desc.output_arg_names() - for arg_name in output_arg_names: - dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - assert len(dims_mapping) == max_dims_mapping_len - dims_mapping_list.append(dims_mapping) - - compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list) - assert compatible_dims_mapping is not None, "There is no compatible dim mapping." - - for arg_name in input_arg_names: - if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: - new_dims_mapping = [ - -1 for _ in range(input_dims_mapping_lens[arg_name]) - ] - for i in range(input_dims_mapping_lens[arg_name]): - new_idx = (max_dims_mapping_len - - input_dims_mapping_lens[arg_name]) + i - new_dims_mapping[i] = compatible_dims_mapping[new_idx] - if new_dims_mapping != input_dims_mapping_dict[arg_name]: - op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping) - changed = True - else: - if compatible_dims_mapping != input_dims_mapping_dict[arg_name]: - op_dist_attr.set_input_dims_mapping(arg_name, - compatible_dims_mapping) - changed = True - - for arg_name in output_arg_names: - dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - if compatible_dims_mapping != dims_mapping: - op_dist_attr.set_output_dims_mapping(arg_name, - compatible_dims_mapping) - changed = True - - return changed - - -def get_all_distributed_main_program(serial_program_info, dist_context, - parallelizer): - "Get all distributed main programs by dist_context." - from .dist_context import DistributedOperatorContext, DistributedContext - cluster = serial_program_info.cluster - copied_parallelizer = copy.deepcopy(parallelizer) - all_dist_main_program = [] - ranks = paddle.distributed.get_world_size() if cluster is None else len( - cluster.get_all_devices("GPU")) - for rank_id in range(ranks): - used_dist_context = copy.deepcopy(dist_context) - used_dist_context._dist_op_context = DistributedOperatorContext() - _, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program( - rank_id, used_dist_context) - all_dist_main_program.append(dist_main_program) - - return all_dist_main_program - - -class SerialProgramInfo: - def __init__(self, - train_program, - satrtup_program, - loss, - optimizer, - cluster=None): - self._train_program = train_program - self._startup_program = satrtup_program - self._loss = loss - self._optimizer = optimizer - self._cluster = cluster - - @property - def train_program(self): - return self._train_program - - @property - def startup_program(self): - return self._startup_program - - @property - def loss(self): - return self._loss - - @property - def optimizer(self): - return self._optimizer - - @property - def cluster(self): - return self._cluster - - -def get_standalone_cost_data(distributed_programs): - def _compute_runtime(op_cost, op, vars): - runtime = 0 - try: - runtime = float(op_cost["op_time"]) - except: - return runtime - op_config = op_cost["config"] - total_static_input_size = 0 - total_actual_input_size = 0 - parsed_info = op_config.split("\n") - variable = "(Variable)" - for info in parsed_info: - variable = "(Variable)" if "(Variable)" in info else "(list" - if variable in info: - arg_name_lower = info[:info.find(variable) - 1] - shape_left_boundary = info.find("[") - shape_right_boundary = info.find("]") - assert shape_left_boundary > 0 and shape_right_boundary > 0 and shape_right_boundary > shape_left_boundary, "Get shape failed." 
- shape = info[shape_left_boundary + 1: - shape_right_boundary].split(",") - shape = list(map(lambda x: int(x.strip()), shape)) - dtype_factor = 1 - total_static_input_size += reduce(lambda x, y: x * y, shape) - if op.type == "c_embedding": - arg_name_lower = "w" if arg_name_lower == "weight" else "ids" - for arg_name in op.input_names: - if arg_name.lower() == arg_name_lower: - for var_name in op.input(arg_name): - var = vars[var_name] - total_actual_input_size += reduce( - lambda x, y: x * y, var.shape) - break - assert total_static_input_size > 0 and total_actual_input_size > 0, "Get input size failed." - - actual_runtime = total_actual_input_size / total_static_input_size * runtime - return actual_runtime - - import paddle.cost_model as cm - cost_model = cm.CostModel() - cost_model.static_cost_data() - DEFAULT_MULTIPLE = 2 - OP_NAME_MAPPING = { - "c_embedding": "embedding", - "matmul_v2": "matmul", - "transpose2": "transpose", - "reshape2": "reshape", - "unsqueeze2": "unsqueeze", - "reduce_sum": "sum", - "elementwise_div": "divide" - } - - standalone_cost_data = [] - not_enum_ops = ["create_py_reader", "create_double_buffer_reader", "read"] - for distributed_program in distributed_programs: - cost_data = {} - vars = distributed_program.global_block().vars - for op in distributed_program.global_block().ops: - runtime = 0 - if op.type in not_enum_ops: - cost_data[op.desc.id()] = runtime - continue - dtype = str(vars[op.input_arg_names[0]] - .dtype) if op.input_arg_names else "float32" - if int(op.attr('op_role')) == int(OpRole.Backward): - if "_grad" in op.type: - forward_op_name = op.type[:-5] - if forward_op_name in OP_NAME_MAPPING.keys(): - forward_op_name = OP_NAME_MAPPING[forward_op_name] - op_cost = cost_model.get_static_op_time( - forward_op_name, forward=False, dtype=dtype) - if op_cost: - runtime = _compute_runtime(op_cost, op, vars) - else: - op_cost = cost_model.get_static_op_time( - forward_op_name, dtype=dtype) - if op_cost: - runtime = 2 * _compute_runtime(op_cost, op, vars) - elif int(op.attr('op_role')) == int(OpRole.Forward): - op_name = OP_NAME_MAPPING[ - op.type] if op.type in OP_NAME_MAPPING.keys() else op.type - op_cost = cost_model.get_static_op_time(op_name) - if op_cost: - runtime = _compute_runtime(op_cost, op, vars) - - cost_data[op.desc.id()] = runtime - - standalone_cost_data.append(cost_data) - - return standalone_cost_data - - -def set_dist_op_desc_original_id(dist_op_desc, op_desc, dist_context): - op_id = op_desc.id() - op_original_id = op_desc.original_id() - # First, try to set the original id to the id of the op_desc - if op_id in dist_context._dist_ops_for_program: - dist_op_desc.set_original_id(op_id) - return - # Second, try to set the original id to the original_id of the op_desc - elif op_original_id in dist_context._dist_ops_for_program: - dist_op_desc.set_original_id(op_original_id) - return - # Third, print error infomation if we cannot find the original id - else: - assert False, "Cannot find the original id in the distributed context" diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py index 7047f162aad46..05d71aca5db2c 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py @@ -76,13 +76,15 @@ def forward(self, input): auto.shard_tensor( self.linear0.weight, dist_attr={ - "process_mesh": [0, 1, 2, 3], - "dims_mapping": [-1, 0] + "process_mesh": 
_global_process_mesh, + "dims_mapping": [-1, 1] }) auto.shard_tensor( self.linear1.weight, - dist_attr={"process_mesh": [1, 5], - "dims_mapping": [0, -1]}) + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) elif _global_parallel_strategy == "pp": auto.shard_tensor( self.linear0.weight, @@ -100,15 +102,7 @@ def forward(self, input): out = self.norm(input) out = self.linear0(out) out = F.gelu(out, approximate=True) - # if _global_parallel_strategy == "dp_mp": - # auto.shard_tensor( - # out, - # dist_attr={ - # "process_mesh": [1, 5], - # "dims_mapping": [-1, 1] - # }) out = self.linear1(out) - out = self.linear0(out) out = self.dropout(out) return out @@ -133,13 +127,12 @@ def mlp_pretrain_forward(train_program, start_program): "dims_mapping": [0, -1, -1] }) elif _global_parallel_strategy == "dp_mp": - pass - # auto.shard_tensor( - # input, - # dist_attr={ - # "process_mesh": [2, 4], - # "dims_mapping": [0, -1, -1] - # }) + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1, -1] + }) mlp = MLPLayer( hidden_size=hidden_size, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py index 3575912a8b437..3a28595c833e0 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py @@ -436,13 +436,6 @@ def forward(self, input): out = self.linear2(out) out = F.gelu(out, approximate=True) out = self.linear3(out) - if _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor( - out, - dist_attr={ - "process_mesh": _global_process_mesh[1], - "dims_mapping": [0, -1] - }) return out @@ -494,9 +487,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): no_grad_set=None, callbacks=None) - print("1****************", rank_id, flush=True) - print_program_with_dist_attr(complete_train_program, dist_context) - partitioner = Partitioner(dist_context, rank_id) dist_train_program, dist_startup_prog, dist_params_grads = partitioner.partition( complete_train_program, startup_program, params_grads) @@ -504,11 +494,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): partitioned_optimize_ops = parallelizer._apply_optimize( dist_train_program, dist_startup_prog, dist_params_grads) - print("2****************", rank_id, flush=True) reshard(dist_train_program, dist_startup_prog, rank_id, dist_context) - print("3****************", rank_id, flush=True) - print_program_with_dist_attr(dist_train_program, dist_context) - print("4****************", rank_id, flush=True) return dist_train_program, dist_startup_prog @@ -551,8 +537,8 @@ def test_mapper_dp_mp_pp(self): dist_context = DistributedContext() dist_train_program, dist_startup_prog = get_dist_prog( train_program, startup_program, dist_context, rank_id) - if rank_id == 6: - print_program_with_dist_attr(dist_train_program, dist_context) + # if rank_id == 0: + # print_program_with_dist_attr(dist_train_program, dist_context) dist_programs[rank_id] = [dist_train_program, None] rank_mapping = mapping(dist_programs, cluster) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py index 78ad64b1dd852..ed64fa0630fa1 100755 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py @@ -154,7 
+154,7 @@ def test_update(self): ops = train_program.global_block().ops vars = train_program.global_block().vars from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container - from paddle.distributed.auto_parallel.operators.common import is_elementwise_op + from paddle.distributed.auto_parallel.completion import is_elementwise_like_op from paddle.distributed.auto_parallel.dist_op import DistributedOperator for op in ops: @@ -163,7 +163,7 @@ def test_update(self): if dist_op_impl_container is None: op_dist_attr = dist_context.get_op_dist_attr_for_program(op) dist_op = DistributedOperator(op, op_dist_attr) - if is_elementwise_op(op.type): + if is_elementwise_like_op(op.type): changed = update_op_dims_mapping_by_elementwise_like_dist_impl( dist_op) self.assertFalse(changed) From fce7b08f12e3895a647bb84531953d6b52db31fc Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Tue, 18 Jan 2022 06:45:50 +0000 Subject: [PATCH 05/11] Recover some modifications --- .../auto_parallel/operators/__init__.py | 2 - .../distributed/auto_parallel/reshard.py | 1 - .../paddle/distributed/auto_parallel/rules.py | 94 -- .../paddle/distributed/auto_parallel/utils.py | 1415 +++++++++++++++++ 4 files changed, 1415 insertions(+), 97 deletions(-) delete mode 100644 python/paddle/distributed/auto_parallel/rules.py create mode 100644 python/paddle/distributed/auto_parallel/utils.py diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index 45854052dda4d..c28b7930124dd 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -22,8 +22,6 @@ from . import dist_reshape from . import dist_softmax from . import dist_transpose -from . import dist_eltwise -from . import dist_split from . import dist_default from . import dist_check_finite_and_unscale from . import dist_update_loss_scaling diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index d92f9739bf97b..6e6d2a672fd18 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -518,7 +518,6 @@ def _insert_recv_op(block, idx, tensor, src): def _insert_concat_op(block, idx, tensors, axis): """Insert concat op into block at the given block.""" inputs = {'X': tensors} - print("concat_op", inputs, flush=True) attrs = {} attrs['axis'] = axis helper = LayerHelper('concat', **locals()) diff --git a/python/paddle/distributed/auto_parallel/rules.py b/python/paddle/distributed/auto_parallel/rules.py deleted file mode 100644 index b48c90ec4ddd3..0000000000000 --- a/python/paddle/distributed/auto_parallel/rules.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import copy - - -def compute_compatible_process_mesh(process_mesh_list): - """Compute the compatible process mesh given a list of process meshes.""" - if not process_mesh_list: - return None - - def _compute_compatible_process_mesh_two(pm1, pm2): - if pm1 is None: - return True, pm2 - if pm2 is None: - return True, pm1 - if pm1 == pm2: - return True, pm1 - if pm1.processes == pm2.processes: - if len(pm1.topology) >= len(pm2.topology): - return True, pm1 - else: - return True, pm2 - process_set1 = set(pm1.processes) - process_set2 = set(pm2.processes) - if process_set1.issubset(process_set2): - return True, pm2 - if process_set2.issubset(process_set1): - return True, pm1 - return False, None - - compatible_result = None - for process_mesh in process_mesh_list: - compatible, compatible_result = _compute_compatible_process_mesh_two( - compatible_result, process_mesh) - if not compatible: - return None - return copy.deepcopy(compatible_result) - - -def compute_compatible_dim_mapping(dim_mapping_list): - """Compute the compatible dim mapping given a list of dim mapping.""" - if not dim_mapping_list: - return None - - def _compute_compatible_dim_mapping_two(dm1, dm2): - if dm1 == -1: - return True, dm2 - if dm2 == -1: - return True, dm1 - if dm1 == dm2: - return True, dm1 - return False, None - - compatible_result = -1 - for mapping in dim_mapping_list: - compatible, compatible_result = _compute_compatible_dim_mapping_two( - compatible_result, mapping) - if not compatible: - return None - return compatible_result - - -def compute_compatible_dims_mapping(dims_mapping_list): - """Compute the compatible dims mapping given a list of dims mapping. - Each of dims mapping is also a list. - """ - if not dims_mapping_list: - return None - length = len(dims_mapping_list[0]) - for dims_mapping in dims_mapping_list: - if dims_mapping is None: - return None - if len(dims_mapping) != length: - return None - compatible_result = [] - for dim_mappings in zip(*dims_mapping_list): - compatible_dim_mapping = compute_compatible_dim_mapping( - list(dim_mappings)) - if compatible_dim_mapping is None: - return None - compatible_result.append(compatible_dim_mapping) - return compatible_result diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py new file mode 100644 index 0000000000000..1867731974f11 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -0,0 +1,1415 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +import os +import copy +import paddle +import threading +import numpy as np +import warnings +import logging +from functools import reduce + +import paddle.fluid.core as core +from paddle.framework.io import _to_LodTensor +from paddle.distributed.fleet.meta_optimizers.common import OpRole +from paddle.fluid.io import is_parameter, is_belong_to_optimizer +from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute + + +def is_valid_list_index(list, index): + if index >= -len(list) and index < len(list): + return True + else: + return False + + +def is_dim_shard(mapping): + if mapping != -1: + return True + else: + return False + + +def is_dim_replicate(mapping): + if mapping == -1: + return True + else: + return False + + +def compute_compatible_dim_mapping(dim_mappings): + if not dim_mappings: + return None + compatible_mapping = dim_mappings[0] + for mapping in dim_mappings: + if compatible_mapping == -1: + compatible_mapping = mapping + elif mapping == -1: + continue + elif compatible_mapping == mapping: + continue + else: + return None + return compatible_mapping + + +def compute_compatible_dims_mapping(dims_mapping_list): + if not dims_mapping_list: + return None + length = len(dims_mapping_list[0]) + for dims_mapping in dims_mapping_list: + assert dims_mapping is not None, \ + "Dims mapping must not be None for compatible computation" + assert len(dims_mapping) == length, \ + "The length of dims_mapping in list must be same for compatible computation." + compatible_result = [] + for dim_mappings in zip(*dims_mapping_list): + compatible_dim_mapping = compute_compatible_dim_mapping( + list(dim_mappings)) + if compatible_dim_mapping is None: + return None + compatible_result.append(compatible_dim_mapping) + return compatible_result + + +def compute_compatible_process_mesh(process_mesh_list): + compatible_process_mesh = None + if not process_mesh_list: + return compatible_process_mesh + for process_mesh in process_mesh_list: + if process_mesh is not None: + if compatible_process_mesh is None or compatible_process_mesh == process_mesh: + compatible_process_mesh = process_mesh + else: + return None + return compatible_process_mesh + + +def compute_compatible_and_update_dim_mapping(dims_mapping_list, index_list): + assert len(dims_mapping_list) == len(index_list) + changed = False + dim_mappings = [] + for i in range(len(dims_mapping_list)): + assert is_valid_list_index(dims_mapping_list[i], index_list[i]) + dim_mappings.append(dims_mapping_list[i][index_list[i]]) + compatible_dim_mapping = compute_compatible_dim_mapping(dim_mappings) + if compatible_dim_mapping is None: + return False + for i in range(len(dims_mapping_list)): + if compatible_dim_mapping != dims_mapping_list[i][index_list[i]]: + dims_mapping_list[i][index_list[i]] = compatible_dim_mapping + changed = True + return changed + + +def append_distributed_attr_suffix(name): + """ + Append auto parallel suffix for distributed attribute name. + """ + return name + core.kAutoParallelSuffix() + + +def remove_distributed_attr_suffix(name): + """ + Remove auto parallel suffix from distributed attribute name. 
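compute_compatible_and_update_dim_mapping above reconciles one dimension across several dims_mapping lists and mutates them in place; a short trace with invented mappings:

    from paddle.distributed.auto_parallel.utils import (
        compute_compatible_and_update_dim_mapping)

    x_mapping, y_mapping = [0, -1], [-1, 1]
    # reconcile dimension 0 of x with dimension 0 of y
    changed = compute_compatible_and_update_dim_mapping([x_mapping, y_mapping], [0, 0])
    assert changed
    assert x_mapping == [0, -1] and y_mapping == [0, 1]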
+ """ + return name.strip(core.kAutoParallelSuffix()) + + +def check_distributed_attr_for_program(program, dist_context=None): + from .dist_context import get_default_distributed_context + if dist_context is None: + dist_context = get_default_distributed_context() + assert dist_context.is_initialized_for_program(), \ + "Distributed attributes must be initialized before check." + for block in program.blocks: + for tensor in block.vars.values(): + dist_tensor = dist_context.get_dist_tensor_for_graph(tensor) + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( + tensor) + if (tensor_dist_attr is not None) and (not dist_tensor.is_valid()): + return False + for op in block.ops: + dist_op = dist_context.get_dist_op_for_graph(tensor) + op_dist_attr = dist_context.get_op_dist_attr_for_program(op) + if (op_dist_attr is not None) and (not dist_op.is_valid()): + return False + return True + + +def print_program_with_dist_attr(program, dist_context=None): + """ + This function reuses the original program output ability with a distributed context. + Using lock can avoid multiple threads change the default distributed context simultaneously. + """ + lock = threading.Lock() + lock.acquire() + from .dist_context import get_default_distributed_context + from .dist_context import set_default_distributed_context + if dist_context is None: + dist_context = get_default_distributed_context() + print(program) + else: + original_default_context = get_default_distributed_context() + set_default_distributed_context(dist_context) + print(program) + set_default_distributed_context(original_default_context) + lock.release() + + +def _get_comm_group(processes, shape, axis, rank): + """ + Given a rank and the processes mesh the rank belongs to, + compute the communication peers of the rank based on the give axis in the mesh. + + Example: 16 processes managed in a 4-Dimensinal mesh with shape of [2, 2, 2, 2]. + the rank communication peers of rank 0 (included) are following: + in axis 0: [0, 1] + in axis 1: [0, 2] + in axis 2: [0, 4] + in axis 3: [0, 8] + """ + + # NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous + # tricks to support processes mesh when it is not start with 0 or continuous + assert rank in processes, "rank [{}] is NOT in processes group {}".format( + rank, processes) + rank_relatvie = processes.index(rank) + coordinate = _linear_idx2coordinate(shape, rank_relatvie) + coordinates_in_group = [coordinate[:] for i in range(shape[axis])] + + # select comm group + for i in range(shape[axis]): + coordinates_in_group[i][axis] = i + + ranks_in_group_relative = [ + _coordinate2linear_idx(shape, coordinate) + for coordinate in coordinates_in_group + ] + ranks_in_group = [processes[idx] for idx in ranks_in_group_relative] + + return sorted(ranks_in_group) + + +def _get_idx_in_axis(processes, shape, axis, rank): + """ + Given a rank and the processes mesh the rank belongs to, + compute the index of the rank in given axis. + + Example: 27 processes managed in a 3-Dimensinal mesh with shape of [3, 3, 3]. 
+        the indices of rank 22 are:
+        in axis 0: 2
+        in axis 1: 1
+        in axis 2: 1
+    """
+
+    # NOTE _linear_idx2coordinate assumes the process mesh starts from 0 and is contiguous;
+    # the relative-index trick below supports meshes that do not satisfy this.
+    rank_relative = processes.index(rank)
+    coordinate = _linear_idx2coordinate(shape, rank_relative)
+    return coordinate[axis]
+
+
+def _coordinate2linear_idx(mesh_shape, coordinate):
+    """
+    Convert a coordinate in the multidimensional mesh space into a scalar index in linear space.
+
+    Row-major order is used for the conversion, so the shape is ordered as
+    [most_significant_dim, ..., least_significant_dim].
+    Assume:
+
+        the size of the i-th dimension is S[i]
+        the index in the j-th dimension is I[j]
+
+    The linear_idx of an n-dimensional coordinate is:
+
+        I[0]   * (S[1] * S[2] * S[3] *   ....   S[n-1]) +
+        I[1]   * (       S[2] * S[3] *   ....   S[n-1]) +
+        I[2]   * (              S[3] *   ....   S[n-1]) +
+        ...
+        I[n-2] * (                               S[n-1]) +
+        I[n-1]
+
+    """
+    # NOTE the following function relies on a strong assumption:
+    # the processes in the mesh
+    #    1. start from 0
+    #    2. are contiguous
+    # It will be wrong if the above conditions are not met,
+    # e.g. process_mesh = { process_groups = [7, 8, 9, 10, 12, 13, 14, 15], mesh = [2, 4]}.
+    # If you want a more general mapping, you should use the cartesian product.
+
+    assert len(mesh_shape) == len(
+        coordinate
+    ), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format(
+        mesh_shape, coordinate)
+    for i in range(len(mesh_shape)):
+        assert coordinate[
+            i] >= 0, "index in dimension [{}] is less than zero. coordinate: {}".format(
+                i, coordinate)
+        assert coordinate[i] < mesh_shape[
+            i], "index beyond the extent in dimension [{}]. shape: {}, coordinate: {}".format(
+                i, mesh_shape, coordinate)
+
+    base = mesh_shape[-1]
+    linear_idx = coordinate[-1]
+
+    # row-major order: the last dimension varies fastest
+    for i in range(len(mesh_shape) - 2, -1, -1):
+        linear_idx += base * coordinate[i]
+        base *= mesh_shape[i]
+
+    return linear_idx
+
+
+def _linear_idx2coordinate(mesh_shape, linear_idx):
+    """
+    Map a linear scalar index into the multidimensional mesh space and return its coordinate in that space.
+
+    It is the inverse of _coordinate2linear_idx.
+    Assume:
+
+        the size of the i-th dimension is S[i]
+        the index in the j-th dimension is I[j]
+
+    The coordinate for a given linear_idx is:
+
+        I[n-1] = linear_idx                                % S[n-1]
+        I[n-2] = (linear_idx / S[n-1])                     % S[n-2]
+        I[n-3] = (linear_idx / (S[n-1] * S[n-2]))          % S[n-3]
+        ....
+
+    """
+
+    assert linear_idx >= 0, "linear index [{}] is less than zero".format(
+        linear_idx)
+    assert linear_idx < np.prod(
+        mesh_shape
+    ), "linear index beyond the extent of mesh shape. shape: {}, linear index: {}".format(
+        mesh_shape, linear_idx)
+
+    base = 1
+    coordinate = [-1] * len(mesh_shape)
+
+    for i in reversed(range(len(mesh_shape))):
+        offset = linear_idx // base
+        coordinate[i] = int(offset % mesh_shape[i])
+        base *= mesh_shape[i]
+
+    # row-major order
+    return coordinate
+
+
+def _get_corresponding_rank(dist_context, target_mesh, rank):
+
+    # TODO(JZ-LIANG) a hack to support varying meshes in the pipeline parallelism case.
+    # We assume that all meshes are evenly divided from a parent mesh and have the same size;
+    # revisit this in the future.
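+    # A minimal illustration of the mapping, with made-up meshes: if rank 5 sits
+    # at position 2 of a registered mesh with processes=[2, 3, 5, 7] and
+    # topology=[2, 2], its coordinate is _linear_idx2coordinate([2, 2], 2) == [1, 0];
+    # for a target mesh with processes=[11, 12, 13, 14] and the same topology,
+    # the corresponding rank is
+    # target_mesh.processes[_coordinate2linear_idx([2, 2], [1, 0])] == 13.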
+ + coordinate = None + for mesh in dist_context.process_meshes: + if rank in mesh.processes and mesh.topology == target_mesh.topology: + coordinate = _linear_idx2coordinate(mesh.topology, + mesh.processes.index(rank)) + break + + assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format( + rank) + return target_mesh.processes[_coordinate2linear_idx(mesh.topology, + coordinate)] + + +def _get_unshard_dist_shape(var, dist_attr): + var_shape = var.shape + mapping = dist_attr.dims_mapping + mesh = dist_attr.process_mesh.topology + assert len(var_shape) == len( + mapping + ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( + var_shape, mapping) + new_shape = [] + for idx in range(len(var_shape)): + if var_shape[idx] == -1 or mapping[idx] == -1: + new_shape.append(var_shape[idx]) + else: + new_shape.append(var_shape[idx] * mesh[mapping[idx]]) + + return new_shape + + +def make_data_unshard(dist_main_prog, dist_startup_prog, dist_context=None): + from .dist_context import get_default_distributed_context + if dist_context is None: + dist_context = get_default_distributed_context() + + for var in dist_main_prog.list_vars(): + if var.is_data: + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( + var) + inverse_shape = _get_unshard_dist_shape(var, tensor_dist_attr) + var.desc.set_shape(inverse_shape) + dim_mapping = tensor_dist_attr.dims_mapping + dim_mapping = [-1] * len(dim_mapping) + tensor_dist_attr.dims_mapping = dim_mapping + dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr) + + +def _update_addition_info(addition_info): + """ Update default addition_info with inputs """ + add_info = {"epoch": 0, "batch": 0, "batch_size": 0} + if not addition_info: + return add_info + elif not isinstance(addition_info, dict): + raise TypeError("The type of 'addition_info' should be 'dict', " + "but got '{}'.".format(str(type(addition_info)))) + else: + for item, value in addition_info.items(): + if item not in ["epoch", "batch", "batch_size"]: + raise ValueError( + "The key of 'addition_info' should be one of the " + "['epoch', 'batch', 'batch_size'], but got '{}'." + .format(str(item))) + if not isinstance(value, int): + raise ValueError( + "The value of 'addition_info' should be 'int', " + "but got '{}'.".format(str(type(value)))) + add_info[item] = value + return add_info + + +def _check_valid_path(file_path): + """ Validity check of input file path """ + if not file_path: + return file_path + elif isinstance(file_path, list): + for file in file_path: + if not isinstance(file, str): + raise TypeError("The type of file path should be 'str', " + "but got '{}'.".format(str(type(file)))) + if not os.path.exists(file): + raise ValueError("The file path '{}' does not exist." 
+ .format(file)) + return file_path + else: + raise TypeError("The type of file path should be 'list', " + "but got '{}'.".format(str(type(file_path)))) + + +def _check_param_dict(param_dict): + if not param_dict: + raise ValueError("'param_dict' cannot be None.") + elif not isinstance(param_dict, dict): + raise TypeError("The type of 'param_dict' should be 'dict', " + "but got '{}'.".format(str(type(param_dict)))) + else: + for name, value in param_dict.items(): + if not isinstance(name, str): + raise TypeError( + "The type of key of 'param_dict' should be 'str', " + "but got '{}'.".format(str(type(name)))) + if not isinstance(value, paddle.fluid.LoDTensor): + raise TypeError( + "The type of value of 'param_dict' should be 'LoDTensor', " + "but got '{}'.".format(str(type(value)))) + return param_dict + + +def _check_dist_attr(dist_attr): + if not dist_attr: + return dist_attr + elif not isinstance(dist_attr, dict): + raise TypeError("The type of 'dist_attr' should be 'dict', " + "but got '{}'.".format(str(type(dist_attr)))) + else: + for name, value in dist_attr.items(): + if not isinstance(name, str): + raise TypeError( + "The type of param name of 'dist_attr' should be 'str', " + "but got '{}'.".format(str(type(name)))) + if not isinstance(value, dict): + raise TypeError( + "The type of distributed attribute should be 'dict', " + "but got '{}'".format(str(type(value)))) + attr = ['process_shape', 'process_group', 'dims_mapping'] + if list(value.keys()) != attr: + raise ValueError( + "The key of distributed attribute should be " + "'['process_shape', 'process_group', 'dims_mapping']', " + "but got {}.".format(str(value.keys()))) + return dist_attr + + +def save_distributed_checkpoint(program, + checkpoint_path, + dist_attr_path, + addition_info=None, + is_integrated=False, + dist_context=None): + """ + Save model parameter state, optimzer state, distributed attribute and + additional information of each rank. + + Args: + program(Program): The program to be saved. + checkpoint_path(str): The path of the checkpoint file to be saved. + dist_attr_path(str): The path of distributed attribute file to be saved. + addition_info(dict, optional): Additional information, key should be selected in ['epoch', 'batch', 'batch_size']. + Default values are 0, when 'addition_info' is None. Default: None. + is_integrated(bool, optional): Whether to integrate param before save. Default: False. + dist_context(DistributedContext ,optional): collect related distributed information for program + + Returns: + None + + Examples: + .. code-block:: python + + path = os.path.join("./output", "step_%d" % step) + os.makedirs(path, exist_ok=True) + add_info = {'batch': step, "batch_size": global_batch_size} + save_distributed_checkpoint(program, path, path, add_info) + """ + from .dist_context import get_default_distributed_context + + assert isinstance(program, paddle.fluid.framework.Program) + assert isinstance(is_integrated, bool) + if dist_context is None: + dist_context = get_default_distributed_context() + addition_info = _update_addition_info(addition_info) + + if not is_integrated: + _save_distributed_state_dict(program, addition_info, checkpoint_path) + _save_distributed_attribute(program, dist_attr_path, dist_context) + else: + # TODO: integrate param before save + raise NotImplementedError( + "Integrating parameter has not been implemented.") + + +def load_distributed_checkpoint(checkpoint_path, dist_attr_path): + """ + Load parameter, optimizer, distributed attribute and addition_info. 
+ + Args: + checkpoint_path(list[str]): model parameter file path, must be in order of rank id. + dist_attr_path(list[str]): distributed attribute file path, must be in order of rank id. + + Returns: + param_dict(dict): parameters' value of all ranks. + dist_attr(dict): parameters' distributed attribute. + addition_info(dict): additional information user saved in last training. + + Notes: + The return, 'addition_info', is belonging to the first file of checkpoint_path by default. + + Examples: + .. code-block:: python + + ckpt_path = ['./model_state_rank0.pdmodel', + './model_state_rank1.pdmodel'] + dist_attr_path = ['./dist_attr_rank0.pdattr', + './dist_attr_rank1.pdattr'] + param_dict, dist_attr, add_info = load_distributed_checkpoint(ckpt_path, dist_attr_path) + """ + assert _check_valid_path(checkpoint_path), \ + "'checkpoint_path' cannot be None." + assert _check_valid_path(dist_attr_path), \ + "'dist_attr_path' cannot be None." + + state_dict_info = _load_distributed_state_dict(checkpoint_path) + dist_attr = _load_distributed_attribute(dist_attr_path) + param_dict = state_dict_info["model"] + addition_info = state_dict_info["addition_info"] + return param_dict, dist_attr, addition_info + + +def load_checkpoint_into_program(checkpoint_path, + dist_attr_path, + program, + dist_context=None): + """ + Load parameter, optimizer, distributed attribute and addition_info into model. + + Args: + checkpoint_path(list[str]): model parameter file path, must be in order of rank id. + dist_attr_path(list[str]): distributed attribute file path, must be in order of rank id. + program(Program): the program to be updated with checkpoint_path. + dist_context(DistributedContext ,optional): collect related distributed information for program + + Returns: + addition_info(dict): user saved in last train. + + Notes: + The return, 'addition_info', is belonging to the first file of checkpoint_path by default. + + Examples: + .. code-block:: python + + exe.run(startup_program) + ckpt_path = ['./model_state_rank0.pdmodel', + './model_state_rank1.pdmodel'] + dist_attr_path = ['./dist_attr_rank0.pdattr', + './dist_attr_rank1.pdattr'] + load_checkpoint_into_program(ckpt_path, dist_attr_path, main_program) + """ + from .dist_context import get_default_distributed_context + + assert isinstance(program, paddle.fluid.framework.Program) + assert _check_valid_path(checkpoint_path), \ + "'checkpoint_path' cannot be None." + assert _check_valid_path(dist_attr_path), \ + "'dist_attr_path' cannot be None." + if dist_context is None: + dist_context = get_default_distributed_context() + all_state_dict_info = _load_distributed_state_dict(checkpoint_path) + all_pre_dist_attr = _load_distributed_attribute(dist_attr_path) + all_cur_dist_attr = get_dist_attr(program, dist_context) + all_param_dict = all_state_dict_info["model"] + addition_info = all_state_dict_info["addition_info"] + sliced_param_dict = merge_and_slice_parameter( + all_param_dict, all_pre_dist_attr, all_cur_dist_attr) + load_parameter_into_program(sliced_param_dict, program) + + return addition_info + + +def load_parameter_into_program(param_dict, program): + """ + Load parameters into program. + + Args: + param_dict(dict): parameters' name and value. 
+ program(Program): the program to be updated + """ + _check_param_dict(param_dict) + assert program and isinstance(program, paddle.fluid.framework.Program) + program.set_state_dict(param_dict) + + +def _save_distributed_attribute(program, dist_attr_path, dist_context): + """ Save distributed attribute of all parameters """ + # TODO: just save a complete distributed attribute file + rank_id = paddle.distributed.get_rank() + dist_attr_name = os.path.join(dist_attr_path, + "dist_attr_rank{}.pdattr".format(rank_id)) + dist_attr_dict = { + "model": get_dist_attr(program, dist_context), + "world_size": paddle.distributed.get_world_size() + } + paddle.save(dist_attr_dict, dist_attr_name) + logging.info("Already saved distributed attribute to '{}'.".format( + dist_attr_path)) + + +def _load_distributed_attribute(dist_attr_path): + """ Load parameters' distributed attribute from dist_attr_path """ + total_dist_attr = {} + for dist_attr_file in dist_attr_path: + dist_attr = paddle.load(dist_attr_file) + pre_world_size = dist_attr["world_size"] + assert pre_world_size == len(dist_attr_path), \ + "The number of 'dist_attr_path' must be equal to the last training world size." + for name, attr in dist_attr["model"].items(): + if name not in total_dist_attr: + total_dist_attr[name] = attr + + return total_dist_attr + + +def _save_distributed_state_dict(program, addition_info, checkpoint_path): + """ Save parameters' state_dict """ + rank = paddle.distributed.get_rank() + ckpt_file_name = os.path.join(checkpoint_path, + "model_state_rank{}.pdmodel".format(rank)) + state_dict = { + "model": program.state_dict(), + "world_size": paddle.distributed.get_world_size(), + "addition_info": addition_info + } + paddle.save(state_dict, ckpt_file_name) + logging.info("Already saved model to '{}'.".format(checkpoint_path)) + + +def _load_distributed_state_dict(checkpoint_path): + """ Load parameters' state_dict from checkpoint_path """ + all_state_dict = {} + for idx, ckpt_file in enumerate(checkpoint_path): + state_dict_info = paddle.load(ckpt_file, return_numpy=True) + pre_world_size = state_dict_info["world_size"] + assert pre_world_size == len(checkpoint_path), \ + "The number of 'checkpoint_path' must be equal to the last training world size." + if idx == 0: + addition_info = state_dict_info["addition_info"] + for name, value in state_dict_info["model"].items(): + if name in all_state_dict: + all_state_dict[name].append(np.array(value)) + else: + all_state_dict[name] = [np.array(value)] + + all_state_dict_info = { + "model": all_state_dict, + "addition_info": addition_info + } + return all_state_dict_info + + +def get_dist_attr(program, dist_context=None): + """ + Get distributed attribute of current rank. 
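+    The result maps each parameter/optimizer variable name to a dict with the
+    keys "process_shape", "process_group" and "dims_mapping"; an entry may look
+    like {"process_shape": [2], "process_group": [0, 1], "dims_mapping": [-1, 0]}
+    (illustrative values only).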
+
+    Args:
+        program(Program): main program for training
+    """
+    from .dist_context import get_default_distributed_context
+
+    assert isinstance(program, paddle.fluid.framework.Program)
+    if dist_context is None:
+        dist_context = get_default_distributed_context()
+    dist_attr = {}
+    for var in program.list_vars():
+        if is_parameter(var) or is_belong_to_optimizer(var):
+            tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
+                var)
+            process_mesh = tensor_dist_attr.process_mesh
+            dims_mapping = tensor_dist_attr.dims_mapping
+            dist_attr[var.name] = {
+                "process_shape": process_mesh.topology,
+                "process_group": process_mesh.processes,
+                "dims_mapping": dims_mapping
+            }
+    return dist_attr
+
+
+def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
+    """
+    Merge parameters with the previous dist_attr and slice them with the current dist_attr.
+
+    Args:
+        dist_param_dict(dict): parameters' value of all ranks.
+        pre_dist_attr(dict): parameters' dist_attr of the last training process.
+        cur_dist_attr(dict): parameters' dist_attr of the current training process.
+
+    Returns:
+        dist_param_dict(dict): parameters' value of the current rank.
+    """
+    assert _check_dist_attr(pre_dist_attr), "'pre_dist_attr' cannot be None."
+    assert _check_dist_attr(cur_dist_attr), "'cur_dist_attr' cannot be None."
+    assert isinstance(dist_param_dict, dict), \
+        "The type of 'dist_param_dict' should be 'dict', but got {}.".format(
+            str(type(dist_param_dict)))
+    for name, value in dist_param_dict.items():
+        if not isinstance(name, str):
+            raise TypeError("The key of 'dist_param_dict' is parameter's name, "
+                            "and its type should be 'str', but got {}."
+                            .format(str(type(name))))
+        if not isinstance(value, list) or not all(
+                isinstance(v, np.ndarray) for v in value):
+            raise TypeError(
+                "The value of 'dist_param_dict' is parameter's value of all ranks, "
+                "and its type should be 'list(numpy.ndarray)'.")
+
+    param_not_in_pre = []
+    param_not_in_cur = []
+    logging.info("Start to merge and slice parameters.")
+    for var_name in cur_dist_attr.keys():
+        if var_name not in pre_dist_attr:
+            param_not_in_pre.append(var_name)
+            continue
+
+        pre_attr = pre_dist_attr[var_name]
+        cur_attr = cur_dist_attr[var_name]
+        if pre_attr == cur_attr:
+            # skip merge and slice
+            rank_id = paddle.distributed.get_rank()
+            index = cur_attr["process_group"].index(rank_id)
+            param = dist_param_dict[var_name][index]
+            dist_param_dict[var_name] = _to_LodTensor(param)
+            continue
+
+        pre_param = dist_param_dict[var_name]
+        pre_dims_mapping = pre_attr["dims_mapping"]
+        cur_dims_mapping = cur_attr["dims_mapping"]
+        if len(set(pre_dims_mapping)) > 1 or -1 not in pre_dims_mapping:
+            complete_param = _merge_parameter_with_dist_attr(pre_param,
+                                                             pre_attr)
+            dist_param_dict[var_name] = complete_param
+        else:
+            complete_param = pre_param[0]
+            dist_param_dict[var_name] = _to_LodTensor(complete_param)
+
+        if len(set(cur_dims_mapping)) > 1 or -1 not in cur_dims_mapping:
+            sliced_param = _slice_parameter_with_dist_attr(complete_param,
+                                                           cur_attr)
+            dist_param_dict[var_name] = sliced_param
+
+    for var_name in pre_dist_attr:
+        if var_name not in cur_dist_attr:
+            param_not_in_cur.append(var_name)
+            dist_param_dict.pop(var_name)
+
+    if param_not_in_pre:
+        warnings.warn("Parameters '{}' are not found in the last training process."
+                      .format(str(param_not_in_pre)))
+    if param_not_in_cur:
+        warnings.warn(
+            "Parameters '{}' are not found in the current training process."
+ .format(str(param_not_in_cur))) + + return dist_param_dict + + +def _merge_parameter_with_dist_attr(param_list, dist_attr): + """ Merge parameter with distributed attribute """ + from .reshard import _compute_complete_shape, _compute_partition_index + + dims_mapping = dist_attr["dims_mapping"] + process_shape = dist_attr["process_shape"] + process_group = dist_attr["process_group"] + # get the complete shape of the parameter + complete_shape = _compute_complete_shape(param_list[0].shape, process_shape, + dims_mapping) + # merge the parameter with dist_attr + partition_param_list = [] + merged_partiton = [] + for process in process_group: + partition_index = _compute_partition_index( + process, complete_shape, dims_mapping, process_shape, process_group) + index = process_group.index(process) + if partition_index not in merged_partiton: + merged_partiton.append(partition_index) + _merge_parameter(partition_param_list, param_list[index], + partition_index, complete_shape) + + assert len(partition_param_list) == 1 or not partition_param_list, \ + "Fail to merge parameter" + complete_param = _to_LodTensor(partition_param_list[0][0]) + return complete_param + + +def _slice_parameter_with_dist_attr(param, dist_attr): + """ Slice parameter with distributed attribute """ + param = np.array(param) if isinstance(param, + paddle.fluid.LoDTensor) else param + dims_mapping = dist_attr["dims_mapping"] + process_shape = dist_attr["process_shape"] + process_group = dist_attr["process_group"] + # slice the parameter with dist_attr + partition_index_list = _get_split_indices(param.shape, dims_mapping, + process_shape, process_group) + sliced_param_list = _slice_parameter(param, partition_index_list, + len(partition_index_list)) + # get the current parameter's index in sliced_param_list + rank_id = paddle.distributed.get_rank() + sliced_param_index = _get_sliced_param_index( + rank_id, param.shape, dims_mapping, process_shape, process_group) + sliced_param = _to_LodTensor(sliced_param_list[sliced_param_index]) + return sliced_param + + +def _merge_parameter(partition_param_list, param, partition_index, + complete_shape): + """ + Merge partitial parameters to a complete one. + + Returns: + None + + Examples: + .. 
code-block:: python + + import numpy as np + partition_param_list = [(np.array([[[1.11, 1.12]]]), [[0,1],[0,1],[0,2]])] + param = np.array([[[1.13, 1.14]]]) + partition_index = [[0,1],[0,1],[2,4]] + + _merge_parameter(partition_param_list, param, partition_index) + # partition_param_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])] + """ + from .reshard import _compute_concat_info + + if len(partition_param_list) == 1: + is_complete_data = True + for idx, item in enumerate(partition_param_list[0][1]): + if item[0] != 0 or item[1] != complete_shape[idx]: + is_complete_data = False + break + if is_complete_data: + return + + if not partition_param_list: + partition_param_list.append((param, partition_index)) + else: + i = 0 + while i < len(partition_param_list): + concat_axis, first_order, new_partition = _compute_concat_info( + partition_param_list[i][1], partition_index) + if concat_axis != -1: + if first_order == 0: + new_param = np.concatenate( + (partition_param_list[i][0], param), axis=concat_axis) + else: + new_param = np.concatenate( + (param, partition_param_list[i][0]), axis=concat_axis) + + partition_param_list.pop(i) + _merge_parameter(partition_param_list, new_param, new_partition, + complete_shape) + break + i += 1 + + +def _slice_parameter(complete_param, partition_index_list, length): + """ + Slice a complete parameter. + + Returns: + sliced_param_list(list): sliced parameters with 'partition_index_list' + + Examples: + .. code-block:: python + + import numpy as np + complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) + rank = 2 + complete_shape = [1, 1, 6] + dims_mapping = [-1, -1, 0] + process_shape = [3] + process_group = [0, 1, 2] + + sliced_param_list = _slice_parameter(complete_param, [[], [], [2, 4]], 3) + # [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])] + """ + sliced_param_list = [] + axis = len(complete_param.shape) - length + sliced_param = np.split( + complete_param, partition_index_list[axis], axis=axis) + if length == 1: + return sliced_param + for param in sliced_param: + sliced_param_list.extend( + _slice_parameter(param, partition_index_list, length - 1)) + return sliced_param_list + + +def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape, + process_group): + """ + Get sliced_param's index of current rank in all sliced parameters list. + + Returns: + sliced_param_index(int): the index of sliced param in sliced_param_list + + Examples: + .. 
code-block:: python + + import numpy as np + complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) + rank = 2 + complete_shape = [1, 1, 6] + dims_mapping = [-1, -1, 0] + process_shape = [3] + process_group = [0, 1, 2] + + slice_param = _slice_parameter(complete_param, [[], [], [2, 4]], 3) + # slice_param: + # [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])] + + index = _get_sliced_param_index(rank, complete_shape, dims_mapping + process_shape, process_group) + # index: 2 + """ + from .reshard import _compute_partition_index + + partition_index = _compute_partition_index( + rank, complete_shape, dims_mapping, process_shape, process_group) + sliced_param_index = 0 + for i, shape in enumerate(complete_shape): + if dims_mapping[i] == -1: + slice_shape = shape + else: + slice_shape = shape // process_shape[dims_mapping[i]] + if shape == 1: + index = 0 + else: + index = (partition_index[i][0] + 1) // slice_shape + sliced_param_index = sliced_param_index * (shape // slice_shape) + index + return sliced_param_index + + +def _get_split_indices(complete_shape, dims_mapping, process_shape, + process_group): + """ + Get split indices of every dimension. + + Returns: + split_indices_list(list): the split indices of every dimension of the parameter + + Examples: + .. code-block:: python + + import numpy as np + complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) + complete_shape = [1, 1, 6] + dims_mapping = [-1, -1, 0] + process_shape = [3] + process_group = [0, 1, 2] + + index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group) + # index: [[], [], [2, 4]] + """ + from .reshard import _compute_partition_index + + split_indices_list = [] + for process in process_group: + partition_index = _compute_partition_index( + process, complete_shape, dims_mapping, process_shape, process_group) + if split_indices_list: + for dim in range(len(partition_index)): + split_indices_list[dim].extend(partition_index[dim]) + else: + split_indices_list = partition_index + split_indices_list = list( + map(lambda x, y: list(set(x) - set([y]) - set([0])), split_indices_list, + complete_shape)) + split_indices_list = [sorted(x) for x in split_indices_list] + return split_indices_list + + +def set_grad_var_shape(program, dist_context): + from .operators.common import infer_shape + from paddle.distributed.fleet.meta_optimizers.common import OpRole + + block = program.global_block() + vars = block.vars + for op in block.ops: + + if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: + break + + if op.type in ["sum"]: + continue + if int(op.attr('op_role')) == int(OpRole.Backward): + op_dist_attr = dist_context.get_op_dist_attr_for_program(op) + assert op_dist_attr is not None + + for var_name in op.output_arg_names: + if "@GRAD" not in var_name: + continue + forward_var_name = var_name[:var_name.find("@GRAD")] + if op.type in [ + "c_allreduce_sum", "c_identity", "scale", "cast" + ]: + forward_var_name = op.input_arg_names[0] + elif op.type == "matmul_v2_grad": + forward_var_name = None + for output_name in op.output_names: + if var_name in op.output(output_name): + assert "@GRAD" in output_name + input_name = output_name[:output_name.find("@GRAD")] + assert len(op.input(input_name)) == 1 + forward_var_name = op.input(input_name)[0] + assert forward_var_name is not None + + need_set_shape_list = [ + "reshape2_grad", "softmax_with_cross_entropy_grad", + "transpose2_grad", "softmax_grad", "cross_entropy_grad2", + "dropout_grad" + ] + 
forward_list = [ + "reshape2", "softmax_with_cross_entropy", "transpose2", + "softmax", "cross_entropy2", "dropout" + ] + if op.type in need_set_shape_list: + for forward_op in block.ops: + assert int(forward_op.attr('op_role')) != int( + OpRole.Backward) + idx = need_set_shape_list.index(op.type) + forward_op_name = forward_list[idx] + if forward_op.type == forward_op_name and forward_var_name in forward_op.input_arg_names: + op_dist_attr = dist_context.get_op_dist_attr_for_program( + forward_op) + break + + forward_input_dist_attr = op_dist_attr.get_input_dist_attr( + forward_var_name) + + assert forward_input_dist_attr is not None, f"{forward_var_name}" + forward_var = vars[forward_var_name] + forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( + forward_var) + assert forward_var_dist_attr is not None + grad_var = vars[var_name] + ref_shape = infer_shape(block, forward_var, + forward_var_dist_attr, + forward_input_dist_attr) + + if list(grad_var.shape) != ref_shape: + grad_var.desc.set_shape(ref_shape) + + +OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() +OpRole = core.op_proto_and_checker_maker.OpRole + + +def is_forward_op(op): + ref_role1 = int(core.op_proto_and_checker_maker.OpRole.Forward) + ref_role2 = int(core.op_proto_and_checker_maker.OpRole.Loss) + op_role = int(op.attr('op_role')) + return OP_ROLE_KEY in op.attr_names and (op_role == ref_role1 or + op_role == ref_role2) + + +def is_backward_op(op): + return OP_ROLE_KEY in op.attr_names and \ + int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward) + + +def is_loss_op(op): + return OP_ROLE_KEY in op.attr_names and \ + int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss)) + + +def get_loss_op(block): + loss_ops = [] + for op in block.ops: + if is_loss_op(op): + assert len(op.desc.output_arg_names( + )) == 1, "loss op should only output loss var" + loss_ops.append(op) + + assert len(loss_ops) == 1, "num of loss op is not equal to one" + return loss_ops[0] + + +def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs): + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.dims_mapping = dims_mapping + # TODO get global mesh group + tensor_dist_attr.process_mesh = process_mesh + dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr) + return tensor_dist_attr + + +def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(new_op, process_mesh, + ref_mapping, ctx): + assert process_mesh is not None + assert ref_mapping is not None + + new_op_dist_attr = OperatorDistributedAttribute() + + for input_varname in new_op.desc.input_arg_names(): + new_op_dist_attr.set_input_dims_mapping(input_varname, ref_mapping) + for output_varname in new_op.desc.output_arg_names(): + new_op_dist_attr.set_output_dims_mapping(output_varname, ref_mapping) + + new_op_dist_attr.process_mesh = process_mesh + ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr) + + +def update_op_dims_mapping_by_default_dist_impl(dist_op): + changed = False + op_dist_attr = dist_op.dist_attr + op_desc = dist_op.serial_op.desc + # The following statement will be replaced by a more elegent way + if op_desc.type() == "shape" or op_desc.type() == "slice": + return False + output_names = op_desc.output_names() + xshape_arg_names = [] + if "XShape" in output_names: + xshape_arg_names = op_desc.output("XShape") + batch_dim_mappings = [] + for arg_name in op_desc.input_arg_names(): + serial_tensor = 
dist_op.get_serial_input(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if len(dims_mapping) > 1: + for idx, mapping in enumerate(dims_mapping[1:]): + assert mapping == -1, \ + "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ + .format(op_desc.type(), idx, mapping) + batch_dim_mappings.append(dims_mapping[0]) + for arg_name in op_desc.output_arg_names(): + serial_tensor = dist_op.get_serial_output(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if arg_name not in xshape_arg_names: + if len(dims_mapping) > 1: + for idx, mapping in enumerate(dims_mapping[1:]): + assert mapping == -1, \ + "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ + .format(op_desc.type(), idx, mapping) + batch_dim_mappings.append(dims_mapping[0]) + else: + assert dims_mapping[0] == -1, \ + "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\ + .format(op_desc.type(), mapping) + if len(dims_mapping) > 2: + for idx, mapping in enumerate(dims_mapping[2:]): + assert mapping == -1, \ + "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\ + .format(op_desc.type(), idx, mapping) + batch_dim_mappings.append(dims_mapping[1]) + + compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings) + assert compatible_dim_mapping is not None, "There is no compatible dim mapping." + for arg_name in op_desc.input_arg_names(): + serial_tensor = dist_op.get_serial_input(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if compatible_dim_mapping != dims_mapping[0]: + dims_mapping[0] = compatible_dim_mapping + changed = True + for arg_name in op_desc.output_arg_names(): + serial_tensor = dist_op.get_serial_output(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if arg_name not in xshape_arg_names: + if compatible_dim_mapping != dims_mapping[0]: + dims_mapping[0] = compatible_dim_mapping + changed = True + else: + if compatible_dim_mapping != dims_mapping[1]: + dims_mapping[1] = compatible_dim_mapping + changed = True + + return changed + + +def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op): + changed = False + op_dist_attr = dist_op.dist_attr + op_desc = dist_op.serial_op.desc + input_arg_names = op_desc.input_arg_names() + input_dims_mapping_dict = {} + input_dims_mapping_lens = {} + max_dims_mapping_len = -1 + for arg_name in input_arg_names: + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if max_dims_mapping_len < len(dims_mapping): + max_dims_mapping_len = len(dims_mapping) + input_dims_mapping_dict[arg_name] = dims_mapping + input_dims_mapping_lens[arg_name] = len(dims_mapping) + + dims_mapping_list = [] + for arg_name in input_arg_names: + if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: + new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)] + for i in range(input_dims_mapping_lens[arg_name]): + new_idx = (max_dims_mapping_len - + input_dims_mapping_lens[arg_name]) + i + new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i] + dims_mapping_list.append(new_dims_mapping) + else: + dims_mapping_list.append(input_dims_mapping_dict[arg_name]) + output_arg_names = 
op_desc.output_arg_names() + for arg_name in output_arg_names: + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + assert len(dims_mapping) == max_dims_mapping_len + dims_mapping_list.append(dims_mapping) + + compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list) + assert compatible_dims_mapping is not None, "There is no compatible dim mapping." + + for arg_name in input_arg_names: + if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: + new_dims_mapping = [ + -1 for _ in range(input_dims_mapping_lens[arg_name]) + ] + for i in range(input_dims_mapping_lens[arg_name]): + new_idx = (max_dims_mapping_len - + input_dims_mapping_lens[arg_name]) + i + new_dims_mapping[i] = compatible_dims_mapping[new_idx] + if new_dims_mapping != input_dims_mapping_dict[arg_name]: + op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping) + changed = True + else: + if compatible_dims_mapping != input_dims_mapping_dict[arg_name]: + op_dist_attr.set_input_dims_mapping(arg_name, + compatible_dims_mapping) + changed = True + + for arg_name in output_arg_names: + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if compatible_dims_mapping != dims_mapping: + op_dist_attr.set_output_dims_mapping(arg_name, + compatible_dims_mapping) + changed = True + + return changed + + +def get_all_distributed_main_program(serial_program_info, dist_context, + parallelizer): + "Get all distributed main programs by dist_context." + from .dist_context import DistributedOperatorContext, DistributedContext + cluster = serial_program_info.cluster + copied_parallelizer = copy.deepcopy(parallelizer) + all_dist_main_program = [] + ranks = paddle.distributed.get_world_size() if cluster is None else len( + cluster.get_all_devices("GPU")) + for rank_id in range(ranks): + used_dist_context = copy.deepcopy(dist_context) + used_dist_context._dist_op_context = DistributedOperatorContext() + _, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program( + rank_id, used_dist_context) + all_dist_main_program.append(dist_main_program) + + return all_dist_main_program + + +class SerialProgramInfo: + def __init__(self, + train_program, + satrtup_program, + loss, + optimizer, + cluster=None): + self._train_program = train_program + self._startup_program = satrtup_program + self._loss = loss + self._optimizer = optimizer + self._cluster = cluster + + @property + def train_program(self): + return self._train_program + + @property + def startup_program(self): + return self._startup_program + + @property + def loss(self): + return self._loss + + @property + def optimizer(self): + return self._optimizer + + @property + def cluster(self): + return self._cluster + + +def get_standalone_cost_data(distributed_programs): + def _compute_runtime(op_cost, op, vars): + runtime = 0 + try: + runtime = float(op_cost["op_time"]) + except: + return runtime + op_config = op_cost["config"] + total_static_input_size = 0 + total_actual_input_size = 0 + parsed_info = op_config.split("\n") + variable = "(Variable)" + for info in parsed_info: + variable = "(Variable)" if "(Variable)" in info else "(list" + if variable in info: + arg_name_lower = info[:info.find(variable) - 1] + shape_left_boundary = info.find("[") + shape_right_boundary = info.find("]") + assert shape_left_boundary > 0 and shape_right_boundary > 0 and shape_right_boundary > shape_left_boundary, "Get shape failed." 
+ shape = info[shape_left_boundary + 1: + shape_right_boundary].split(",") + shape = list(map(lambda x: int(x.strip()), shape)) + dtype_factor = 1 + total_static_input_size += reduce(lambda x, y: x * y, shape) + if op.type == "c_embedding": + arg_name_lower = "w" if arg_name_lower == "weight" else "ids" + for arg_name in op.input_names: + if arg_name.lower() == arg_name_lower: + for var_name in op.input(arg_name): + var = vars[var_name] + total_actual_input_size += reduce( + lambda x, y: x * y, var.shape) + break + assert total_static_input_size > 0 and total_actual_input_size > 0, "Get input size failed." + + actual_runtime = total_actual_input_size / total_static_input_size * runtime + return actual_runtime + + import paddle.cost_model as cm + cost_model = cm.CostModel() + cost_model.static_cost_data() + DEFAULT_MULTIPLE = 2 + OP_NAME_MAPPING = { + "c_embedding": "embedding", + "matmul_v2": "matmul", + "transpose2": "transpose", + "reshape2": "reshape", + "unsqueeze2": "unsqueeze", + "reduce_sum": "sum", + "elementwise_div": "divide" + } + + standalone_cost_data = [] + not_enum_ops = ["create_py_reader", "create_double_buffer_reader", "read"] + for distributed_program in distributed_programs: + cost_data = {} + vars = distributed_program.global_block().vars + for op in distributed_program.global_block().ops: + runtime = 0 + if op.type in not_enum_ops: + cost_data[op.desc.id()] = runtime + continue + dtype = str(vars[op.input_arg_names[0]] + .dtype) if op.input_arg_names else "float32" + if int(op.attr('op_role')) == int(OpRole.Backward): + if "_grad" in op.type: + forward_op_name = op.type[:-5] + if forward_op_name in OP_NAME_MAPPING.keys(): + forward_op_name = OP_NAME_MAPPING[forward_op_name] + op_cost = cost_model.get_static_op_time( + forward_op_name, forward=False, dtype=dtype) + if op_cost: + runtime = _compute_runtime(op_cost, op, vars) + else: + op_cost = cost_model.get_static_op_time( + forward_op_name, dtype=dtype) + if op_cost: + runtime = 2 * _compute_runtime(op_cost, op, vars) + elif int(op.attr('op_role')) == int(OpRole.Forward): + op_name = OP_NAME_MAPPING[ + op.type] if op.type in OP_NAME_MAPPING.keys() else op.type + op_cost = cost_model.get_static_op_time(op_name) + if op_cost: + runtime = _compute_runtime(op_cost, op, vars) + + cost_data[op.desc.id()] = runtime + + standalone_cost_data.append(cost_data) + + return standalone_cost_data + + +def set_dist_op_desc_original_id(dist_op_desc, op_desc, dist_context): + op_id = op_desc.id() + op_original_id = op_desc.original_id() + # First, try to set the original id to the id of the op_desc + if op_id in dist_context._dist_ops_for_program: + dist_op_desc.set_original_id(op_id) + return + # Second, try to set the original id to the original_id of the op_desc + elif op_original_id in dist_context._dist_ops_for_program: + dist_op_desc.set_original_id(op_original_id) + return + # Third, print error infomation if we cannot find the original id + else: + assert False, "Cannot find the original id in the distributed context" From 9ba759998cb7615f698748b2ca0f86e6f427d901 Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Tue, 18 Jan 2022 07:35:02 +0000 Subject: [PATCH 06/11] Add lost files --- .../auto_parallel/operators/__init__.py | 1 + .../auto_parallel/operators/dist_eltwise.py | 170 ++++++++++++++++++ 2 files changed, 171 insertions(+) create mode 100755 python/paddle/distributed/auto_parallel/operators/dist_eltwise.py diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py 
b/python/paddle/distributed/auto_parallel/operators/__init__.py index c28b7930124dd..ea743df8d643b 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -23,5 +23,6 @@ from . import dist_softmax from . import dist_transpose from . import dist_default +from . import dist_eltwise from . import dist_check_finite_and_unscale from . import dist_update_loss_scaling diff --git a/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py new file mode 100755 index 0000000000000..7d33692e46af9 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py @@ -0,0 +1,170 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from .common import DistributedOperatorImplContainer +from .common import DistributedOperatorImpl +from .common import register_distributed_operator_impl_container +from .common import register_distributed_operator_impl +from .common import is_elementwise_op +from ..utils import is_dim_shard +from ..utils import is_dim_replicate +from ..utils import is_valid_list_index +from ..utils import compute_compatible_dim_mapping +from ..utils import compute_compatible_dims_mapping +from ..utils import compute_compatible_and_update_dim_mapping +from ..dist_attribute import OperatorDistributedAttribute +from paddle.fluid import core, unique_name +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.framework import Program, Parameter, Variable, program_guard +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY +from ..process_group import new_process_group +from ..utils import _get_comm_group, _get_corresponding_rank +from .dist_default import DistributedDefaultImpl0 + + +class DistributedElementwise(DistributedOperatorImplContainer): + def __init__(self, op_type): + super(DistributedElementwise, self).__init__(op_type) + + +register_distributed_operator_impl_container( + DistributedElementwise("elementwise")) + + +# Replicated Elementwise +class DistributedElementwiseImpl0(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedElementwiseImpl0, self).__init__(name) + self._forward_implemented = False + self._backward_implemented = False + + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + if is_elementwise_op(op_desc.type()): + return True + else: + return False + + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_desc = dist_op.serial_op.desc + if is_elementwise_op(op_desc.type()): + return True + else: + return False + + def is_auto_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + dims_mapping_list = [] + input_arg_names = op_desc.input_arg_names() + max_dims_mapping_len = -1 + for 
arg_name in input_arg_names: + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if max_dims_mapping_len < len(dims_mapping): + max_dims_mapping_len = len(dims_mapping) + dims_mapping_list.append(dims_mapping) + output_arg_names = op_desc.output_arg_names() + for arg_name in output_arg_names: + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + assert len(dims_mapping) == max_dims_mapping_len + dims_mapping_list.append(dims_mapping) + + for idx in range(max_dims_mapping_len): + dim_mappings = [] + for dims_mapping in dims_mapping_list: + if idx < len(dims_mapping): + dim_mappings.append(dims_mapping[-(idx + 1)]) + if not all(dim_mappings[0] == dim_mapping + for dim_mapping in dim_mappings): + return False + return True + + def update_dims_mapping(self, dist_op): + changed = False + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + input_arg_names = op_desc.input_arg_names() + input_dims_mapping_dict = {} + input_dims_mapping_lens = {} + max_dims_mapping_len = -1 + for arg_name in input_arg_names: + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if max_dims_mapping_len < len(dims_mapping): + max_dims_mapping_len = len(dims_mapping) + input_dims_mapping_dict[arg_name] = dims_mapping + input_dims_mapping_lens[arg_name] = len(dims_mapping) + + dims_mapping_list = [] + for arg_name in input_arg_names: + if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: + new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)] + for i in range(input_dims_mapping_lens[arg_name]): + new_idx = (max_dims_mapping_len - + input_dims_mapping_lens[arg_name]) + i + new_dims_mapping[new_idx] = input_dims_mapping_dict[ + arg_name][i] + dims_mapping_list.append(new_dims_mapping) + else: + dims_mapping_list.append(input_dims_mapping_dict[arg_name]) + output_arg_names = op_desc.output_arg_names() + for arg_name in output_arg_names: + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + assert len(dims_mapping) == max_dims_mapping_len + dims_mapping_list.append(dims_mapping) + + compatible_dims_mapping = compute_compatible_dims_mapping( + dims_mapping_list) + assert compatible_dims_mapping is not None, "There is no compatible dim mapping." 
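+        # A small sketch with made-up mappings: for inputs x -> [0, -1, -1] and a
+        # broadcast y -> [-1, -1], plus an output -> [-1, -1, -1], the right-aligned
+        # list is [[0, -1, -1], [-1, -1, -1], [-1, -1, -1]] and the compatible
+        # result is [0, -1, -1]; the loops below then leave y as [-1, -1] for its
+        # two trailing dimensions and update x and the output to [0, -1, -1].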
+ + for arg_name in input_arg_names: + if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: + new_dims_mapping = [ + -1 for _ in range(input_dims_mapping_lens[arg_name]) + ] + for i in range(input_dims_mapping_lens[arg_name]): + new_idx = (max_dims_mapping_len - + input_dims_mapping_lens[arg_name]) + i + new_dims_mapping[i] = compatible_dims_mapping[new_idx] + if new_dims_mapping != input_dims_mapping_dict[arg_name]: + op_dist_attr.set_input_dims_mapping(arg_name, + new_dims_mapping) + changed = True + else: + if compatible_dims_mapping != input_dims_mapping_dict[arg_name]: + op_dist_attr.set_input_dims_mapping(arg_name, + compatible_dims_mapping) + changed = True + + for arg_name in output_arg_names: + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if compatible_dims_mapping != dims_mapping: + op_dist_attr.set_output_dims_mapping(arg_name, + compatible_dims_mapping) + changed = True + + return changed + + @staticmethod + def forward(ctx, *args, **kwargs): + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + + @staticmethod + def backward(ctx, *args, **kwargs): + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) + + +register_distributed_operator_impl( + "elementwise", DistributedElementwiseImpl0("replicate_parallel")) From 45939416da5dce81bf8be781264f12978ffe70b7 Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Tue, 18 Jan 2022 09:24:31 +0000 Subject: [PATCH 07/11] Fix a minor bug --- python/paddle/distributed/auto_parallel/operators/common.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index fc1a1626dad6b..4b079e7b6b575 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -131,9 +131,6 @@ def forward(dist_ctx, *args, **kwargs): def backward(dist_ctx, *grad_outputs, **kwargs): raise NotImplementedError("Please Implement this method in Subclass.") - def is_auto_compatible(self, dist_op): - raise NotImplementedError("Please Implement this method in Subclass.") - def update_dims_mapping(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") From d4371aa6b088ec6ae0fa7d27b9afde093cfb9f3f Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Wed, 19 Jan 2022 11:57:49 +0000 Subject: [PATCH 08/11] Fix the bug of the planner --- python/paddle/distributed/auto_parallel/planner.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/planner.py b/python/paddle/distributed/auto_parallel/planner.py index 5344af2a6698b..138ab93dbd898 100755 --- a/python/paddle/distributed/auto_parallel/planner.py +++ b/python/paddle/distributed/auto_parallel/planner.py @@ -250,7 +250,8 @@ def _enum_valid_dist_attr_for_op(program, op, process_mesh): op, dist_op.dist_attr, vars ) and PlanFilter.check_dims_mapping_for_special_op( op, dist_op.dist_attr, vars): - dist_op.dist_attr.impl_idx = -1 + dist_op.dist_attr.impl_type = "elementwise" + dist_op.dist_attr.impl_idx = 0 op_valid_dist_attrs.append(dist_op.dist_attr) continue else: @@ -266,7 +267,8 @@ def _enum_valid_dist_attr_for_op(program, op, process_mesh): op, dist_op.dist_attr, vars ) and PlanFilter.check_dims_mapping_for_special_op( op, dist_op.dist_attr, vars): - dist_op.dist_attr.impl_idx = -2 + dist_op.dist_attr.impl_type = "default" + dist_op.dist_attr.impl_idx = 0 op_valid_dist_attrs.append(dist_op.dist_attr) continue @@ -276,6 
+278,7 @@ def _enum_valid_dist_attr_for_op(program, op, process_mesh): if impl.is_auto_compatible(dist_op): if PlanFilter.check_dims_mapping_for_op( op, dist_op.dist_attr, vars): + dist_op.dist_attr.impl_type = dist_op.serial_op.type dist_op.dist_attr.impl_idx = idx op_valid_dist_attrs.append(dist_op.dist_attr) @@ -290,7 +293,8 @@ def _enum_valid_dist_attr_for_op(program, op, process_mesh): for var_name in op.output_arg_names: op_dist_attr.set_output_dims_mapping( vars[var_name], [-1 for i in vars[var_name].shape]) - dist_op.dist_attr.impl_idx = -1 + dist_op.dist_attr.impl_type = "default" + dist_op.dist_attr.impl_idx = 0 op_valid_dist_attrs.append(dist_op.dist_attr) return op_valid_dist_attrs From 0e6d3648e4a8b4cd66035063311dfb31fad9ff22 Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Thu, 20 Jan 2022 02:40:13 +0000 Subject: [PATCH 09/11] Fix the format problem --- .../distributed/auto_parallel/dist_context.py | 78 +++++++++---------- .../distributed/auto_parallel/planner.py | 2 +- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index fbc092df9c3dc..ad3a53ff17d76 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -192,17 +192,17 @@ def get_tensor_dist_attr_for_graph(self, serial_tensor_node): else: return None - def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr): - assert serial_tensor_node.is_var() and \ - serial_tensor_node.var() is not None - serial_tensor_id = serial_tensor_node.node.original_desc_id() - dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None) - assert dist_tensor is not None, \ - "The distributed tensor of the program has not been added to this context." - serial_tensor_node_id = serial_tensor_node.id() - new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor, - dist_attr) - self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor + # def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr): + # assert serial_tensor_node.is_var() and \ + # serial_tensor_node.var() is not None + # serial_tensor_id = serial_tensor_node.node.original_desc_id() + # dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None) + # assert dist_tensor is not None, \ + # "The distributed tensor of the program has not been added to this context." + # serial_tensor_node_id = serial_tensor_node.id() + # new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor, + # dist_attr) + # self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor def get_op_dist_attr_for_program(self, serial_op): serial_op_id = serial_op.desc.id() @@ -236,34 +236,34 @@ def get_op_dist_attr_for_graph(self, serial_op_node): else: return None - def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr): - assert serial_op_node.is_op() and \ - serial_op_node.op() is not None - serial_op_id = serial_op_node.node.original_desc_id() - dist_op = self._dist_ops_for_program.get(serial_op_id, None) - assert dist_op is not None, \ - "The distributed operator of the program has not been added to this context." 
- serial_op_node_id = serial_op_node.id() - new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr) - self._dist_ops_for_graph[serial_op_node_id] = new_dist_op - - def get_dist_attr_for_graph(self, serial_node): - if serial_node.is_var() and serial_node.var() is not None: - serial_tensor_node_id = serial_node.id() - dist_tensor = self._dist_tensors_for_graph.get( - serial_tensor_node_id, None) - if dist_tensor: - return dist_tensor.dist_attr - else: - return None - if serial_node.is_op() and serial_node.op() is not None: - serial_op_node_id = serial_node.id() - dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None) - if dist_op: - return dist_op.dist_attr - else: - return None - return None + # def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr): + # assert serial_op_node.is_op() and \ + # serial_op_node.op() is not None + # serial_op_id = serial_op_node.node.original_desc_id() + # dist_op = self._dist_ops_for_program.get(serial_op_id, None) + # assert dist_op is not None, \ + # "The distributed operator of the program has not been added to this context." + # serial_op_node_id = serial_op_node.id() + # new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr) + # self._dist_ops_for_graph[serial_op_node_id] = new_dist_op + + # def get_dist_attr_for_graph(self, serial_node): + # if serial_node.is_var() and serial_node.var() is not None: + # serial_tensor_node_id = serial_node.id() + # dist_tensor = self._dist_tensors_for_graph.get( + # serial_tensor_node_id, None) + # if dist_tensor: + # return dist_tensor.dist_attr + # else: + # return None + # if serial_node.is_op() and serial_node.op() is not None: + # serial_op_node_id = serial_node.id() + # dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None) + # if dist_op: + # return dist_op.dist_attr + # else: + # return None + # return None def init_dist_attr_for_program(self): assert self._serial_program, \ diff --git a/python/paddle/distributed/auto_parallel/planner.py b/python/paddle/distributed/auto_parallel/planner.py index 138ab93dbd898..f7d4c734feea4 100755 --- a/python/paddle/distributed/auto_parallel/planner.py +++ b/python/paddle/distributed/auto_parallel/planner.py @@ -216,7 +216,7 @@ def _enum_valid_dist_attr_for_op(program, op, process_mesh): # compose dims mapping composed_dims_mapping_list = list( product( - * [dims_mapping_dict[key] for key in dims_mapping_dict.keys()])) + *[dims_mapping_dict[key] for key in dims_mapping_dict.keys()])) for composed_dims_mapping in composed_dims_mapping_list: op_dist_attr = OperatorDistributedAttribute() op_dist_attr.process_mesh = process_mesh From 0ec771f377711a50843cffb6ec5107bf0ad3979c Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Thu, 20 Jan 2022 11:17:15 +0000 Subject: [PATCH 10/11] [Auto Parallel] Update the completion algorithm --- .../distributed/auto_parallel/__init__.py | 6 - .../distributed/auto_parallel/completion.py | 1435 +++++++---------- .../distributed/auto_parallel/dist_context.py | 34 +- .../distributed/auto_parallel/parallelizer.py | 15 +- .../test_auto_parallel_completion.py | 66 +- .../test_auto_parallel_completion_gpt.py | 22 +- .../test_auto_parallel_cost_model.py | 6 +- .../test_auto_parallel_dist_tensor.py | 6 +- .../unittests/test_auto_parallel_mapper.py | 12 +- .../test_auto_parallel_partitioner.py | 6 +- .../test_auto_parallel_partitioner_gpt.py | 11 +- .../unittests/test_auto_parallel_reshard.py | 7 +- .../test_auto_parallel_reshard_dpmppp.py | 6 +- .../test_auto_parallel_reshard_mppp.py | 11 +- 14 files changed, 705 
insertions(+), 938 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/__init__.py b/python/paddle/distributed/auto_parallel/__init__.py index 3b5ccaa062f6e..edcd53bdc7a52 100644 --- a/python/paddle/distributed/auto_parallel/__init__.py +++ b/python/paddle/distributed/auto_parallel/__init__.py @@ -15,12 +15,6 @@ from .interface import shard_tensor # noqa: F401 from .interface import shard_op # noqa: F401 from .process_mesh import ProcessMesh -# from .interface import set_shard_mask # noqa: F401 -# from .interface import set_offload_device # noqa: F401 -# from .interface import set_pipeline_stage # noqa: F401 -# from .interface import ProcessMesh # noqa: F401 -from .completion import complete_annotation # noqa: F401 -from .completion import complete_backward_annotation # noqa: F401 from .reshard import reshard # noqa: F401 from .cost_model import estimate_cost diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 660b1a54221a7..e2ece78ec7a92 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy from copy import deepcopy +import time from paddle.fluid import core from paddle.fluid import framework -from .utils import compute_compatible_process_mesh -from .utils import compute_compatible_dim_mapping -from .utils import compute_compatible_dims_mapping from .utils import print_program_with_dist_attr from .operators import find_best_compatible_distributed_operator_impl from .dist_context import get_default_distributed_context @@ -29,865 +28,623 @@ from .dist_attribute import OperatorDistributedAttribute from paddle.distributed.fleet.meta_optimizers.common import OpRole -ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"] +def compute_compatible_process_mesh(process_mesh_list): + """Compute the compatible process mesh given a list of process meshes.""" + if not process_mesh_list: + return None -def is_elementwise_like_op(op_type): - if op_type in ELEMENTWISE_LIKE_OP_LIST: - return True - else: - return False - + def _compute_compatible_process_mesh_two(pm1, pm2): + if pm1 is None: + return True, pm2 + if pm2 is None: + return True, pm1 + if pm1 == pm2: + return True, pm1 + if pm1.processes == pm2.processes: + if len(pm1.topology) >= len(pm2.topology): + return True, pm1 + else: + return True, pm2 + process_set1 = set(pm1.processes) + process_set2 = set(pm2.processes) + if process_set1.issubset(process_set2): + return True, pm2 + if process_set2.issubset(process_set1): + return True, pm1 + return False, None + + compatible_result = None + for process_mesh in process_mesh_list: + compatible, compatible_result = _compute_compatible_process_mesh_two( + compatible_result, process_mesh) + if not compatible: + return None + return copy.deepcopy(compatible_result) + + +def compute_compatible_dim_mapping(dim_mapping_list): + """Compute the compatible dim mapping given a list of dim mapping.""" + if not dim_mapping_list: + return None -def update_tensor_node_process_mesh(dist_context, tensor_node, fwd=True): - """ - Update tensor's process mesh by using its predecessor's process mesh if in the forward direction, - and by using its successor's process mesh if in the backward direction. Note: only the equal - process meshes are compatible for now. 
+ def _compute_compatible_dim_mapping_two(dm1, dm2): + if dm1 == -1: + return True, dm2 + if dm2 == -1: + return True, dm1 + if dm1 == dm2: + return True, dm1 + return False, None + + compatible_result = -1 + for mapping in dim_mapping_list: + compatible, compatible_result = _compute_compatible_dim_mapping_two( + compatible_result, mapping) + if not compatible: + return None + return compatible_result + + +def compute_compatible_dims_mapping(dims_mapping_list): + """Compute the compatible dims mapping given a list of dims mapping. + Each of dims mapping is also a list. """ - changed = False - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node) - if tensor_dist_attr.is_annotated("process_mesh"): - return changed - tensor_process_mesh = tensor_dist_attr.process_mesh - if fwd: - inputs_process_meshes = [] - for pred_op_node in tensor_node.inputs: - if pred_op_node.op() is not None: - op_dist_attr = dist_context.get_op_dist_attr_for_graph( - pred_op_node) - op_process_mesh = op_dist_attr.process_mesh - inputs_process_meshes.append(op_process_mesh) - compatible_process_mesh = compute_compatible_process_mesh( - inputs_process_meshes) - if compatible_process_mesh is not None and tensor_process_mesh is None: - tensor_dist_attr.process_mesh = compatible_process_mesh - changed = True - else: - outputs_process_meshes = [] - for succ_op_node in tensor_node.outputs: - if succ_op_node.op() is not None: - op_dist_attr = dist_context.get_op_dist_attr_for_graph( - succ_op_node) - op_process_mesh = op_dist_attr.process_mesh - outputs_process_meshes.append(op_process_mesh) - compatible_process_mesh = compute_compatible_process_mesh( - outputs_process_meshes) - if compatible_process_mesh is not None and tensor_process_mesh is None: - tensor_dist_attr.process_mesh = compatible_process_mesh - changed = True - return changed - - -def update_op_node_process_mesh(dist_context, op_node, fwd=True): - """ - Update op's process mesh by using its predecessor's process mesh if in the forward direction, - and by using its successor's process mesh if in the backward direction. Note: only the equal - process meshes are compatible for now. 
- """ - changed = False - op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node) - if op_dist_attr.is_annotated("process_mesh"): - return changed - op_process_mesh = op_dist_attr.process_mesh - if fwd: - inputs_process_meshes = [] - for tensor_node in op_node.inputs: - if tensor_node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - tensor_process_mesh = tensor_dist_attr.process_mesh - inputs_process_meshes.append(tensor_process_mesh) - compatible_process_mesh = compute_compatible_process_mesh( - inputs_process_meshes) - if compatible_process_mesh is not None and op_process_mesh is None: - op_dist_attr.process_mesh = compatible_process_mesh - changed = True - else: - outputs_process_meshes = [] - for tensor_node in op_node.outputs: - if tensor_node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - tensor_process_mesh = tensor_dist_attr.process_mesh - outputs_process_meshes.append(tensor_process_mesh) - compatible_process_mesh = compute_compatible_process_mesh( - outputs_process_meshes) - if compatible_process_mesh is not None and op_process_mesh is None: - op_dist_attr.process_mesh = compatible_process_mesh - changed = True - return changed - - -def update_op_dims_mapping_by_default_dist_impl(dist_context, op_node): - """Each operator has a default distributed operator, only allowed to be sharded in batch dimension.""" - changed = False - if (not op_node.is_op()) or (op_node.op() is None): - return False - op_desc = op_node.op() - dist_op = dist_context.get_dist_op_for_graph(op_node) - op_dist_attr = dist_op.dist_attr - # The following statement will be replaced by a more elegent way - if op_desc.type() == "shape" or op_desc.type() == "slice": - return False - output_names = op_desc.output_names() - xshape_arg_names = [] - if "XShape" in output_names: - xshape_arg_names = op_desc.output("XShape") - batch_dim_mappings = [] - for arg_name in op_desc.input_arg_names(): - serial_tensor = dist_op.get_serial_input(arg_name) - if serial_tensor.is_parameter: - continue - dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) - if len(dims_mapping) > 1: - for idx, mapping in enumerate(dims_mapping[1:]): - assert mapping == -1, \ - "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ - .format(op_desc.type(), idx, mapping) - batch_dim_mappings.append(dims_mapping[0]) - for arg_name in op_desc.output_arg_names(): - serial_tensor = dist_op.get_serial_output(arg_name) - if serial_tensor.is_parameter: - continue - dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - if arg_name not in xshape_arg_names: - if len(dims_mapping) > 1: - for idx, mapping in enumerate(dims_mapping[1:]): - assert mapping == -1, \ - "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ - .format(op_desc.type(), idx, mapping) - batch_dim_mappings.append(dims_mapping[0]) - else: - assert dims_mapping[0] == -1, \ - "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\ - .format(op_desc.type(), mapping) - if len(dims_mapping) > 2: - for idx, mapping in enumerate(dims_mapping[2:]): - assert mapping == -1, \ - "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\ - .format(op_desc.type(), idx, mapping) - batch_dim_mappings.append(dims_mapping[1]) - - compatible_dim_mapping = 
compute_compatible_dim_mapping(batch_dim_mappings) - assert compatible_dim_mapping is not None, "There is no compatible dim mapping." - for arg_name in op_desc.input_arg_names(): - serial_tensor = dist_op.get_serial_input(arg_name) - if serial_tensor.is_parameter: - continue - dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) - if compatible_dim_mapping != dims_mapping[0]: - dims_mapping[0] = compatible_dim_mapping - changed = True - for arg_name in op_desc.output_arg_names(): - serial_tensor = dist_op.get_serial_output(arg_name) - if serial_tensor.is_parameter: - continue - dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - if arg_name not in xshape_arg_names: - if compatible_dim_mapping != dims_mapping[0]: - dims_mapping[0] = compatible_dim_mapping + if not dims_mapping_list: + return None + length = len(dims_mapping_list[0]) + for dims_mapping in dims_mapping_list: + if dims_mapping is None: + return None + if len(dims_mapping) != length: + return None + compatible_result = [] + for dim_mappings in zip(*dims_mapping_list): + compatible_dim_mapping = compute_compatible_dim_mapping( + list(dim_mappings)) + if compatible_dim_mapping is None: + return None + compatible_result.append(compatible_dim_mapping) + return compatible_result + + +class Completer: + def __init__(self, dist_context): + assert dist_context is not None + self._dist_context = dist_context + + def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True): + changed = False + if (not tensor_node.is_var()) or (tensor_node.var() is None): + return False + tensor_desc = tensor_node.var() + # Skip reader tensor + if tensor_desc.type() == core.VarDesc.VarType.READER: + return False + tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node) + assert tensor_dist_attr is not None + if tensor_dist_attr.is_annotated("dims_mapping"): + return False + tensor_dims_mapping = tensor_dist_attr.dims_mapping + if fwd: + dims_mapping_list = [] + for pred_op_node in tensor_node.inputs: + if pred_op_node.op() is not None: + if pred_op_node.op().type() == "create_py_reader" \ + or pred_op_node.op().type() == "create_double_buffer_reader" \ + or pred_op_node.op().type() == "read": + continue + op_dist_attr = self._dist_context.get_op_dist_attr_for_graph( + pred_op_node) + if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + op_dims_mapping = op_dist_attr.get_output_dims_mapping( + tensor_desc.name()) + dims_mapping_list.append(op_dims_mapping) + dims_mapping_list.append(tensor_dims_mapping) + compatible_dims_mapping = compute_compatible_dims_mapping( + dims_mapping_list) + if (compatible_dims_mapping is not None) and \ + (compatible_dims_mapping != tensor_dims_mapping): + tensor_dist_attr.dims_mapping = compatible_dims_mapping changed = True else: - if compatible_dim_mapping != dims_mapping[1]: - dims_mapping[1] = compatible_dim_mapping + dims_mapping_list = [] + for succ_op_node in tensor_node.outputs: + if succ_op_node.op() is not None: + if succ_op_node.op().type() == "create_py_reader" \ + or succ_op_node.op().type() == "create_double_buffer_reader" \ + or succ_op_node.op().type() == "read": + continue + op_dist_attr = self._dist_context.get_op_dist_attr_for_graph( + succ_op_node) + if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + op_dims_mapping = op_dist_attr.get_input_dims_mapping( + tensor_desc.name()) + dims_mapping_list.append(op_dims_mapping) + dims_mapping_list.append(tensor_dims_mapping) + compatible_dims_mapping = compute_compatible_dims_mapping( + 
dims_mapping_list) + if (compatible_dims_mapping is not None) and \ + (compatible_dims_mapping != tensor_dims_mapping): + tensor_dist_attr.dims_mapping = compatible_dims_mapping changed = True + return changed - return changed - - -def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_context, op_node): - """Element-wise operator can be sharded in any way (but should take care of broadcasting).""" - changed = False - if (not op_node.is_op()) or (op_node.op() is None): - return False - op_desc = op_node.op() - op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node) - - input_arg_names = op_desc.input_arg_names() - input_dims_mapping_dict = {} - input_dims_mapping_lens = {} - max_dims_mapping_len = -1 - for arg_name in input_arg_names: - dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) - if max_dims_mapping_len < len(dims_mapping): - max_dims_mapping_len = len(dims_mapping) - input_dims_mapping_dict[arg_name] = dims_mapping - input_dims_mapping_lens[arg_name] = len(dims_mapping) - - dims_mapping_list = [] - for arg_name in input_arg_names: - if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: - new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)] - for i in range(input_dims_mapping_lens[arg_name]): - new_idx = (max_dims_mapping_len - - input_dims_mapping_lens[arg_name]) + i - new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i] - dims_mapping_list.append(new_dims_mapping) - else: - dims_mapping_list.append(input_dims_mapping_dict[arg_name]) - output_arg_names = op_desc.output_arg_names() - for arg_name in output_arg_names: - dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - assert len(dims_mapping) == max_dims_mapping_len - dims_mapping_list.append(dims_mapping) - - compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list) - assert compatible_dims_mapping is not None, "There is no compatible dim mapping." 
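# A minimal doctest-style sketch of the merge rule implemented by the
# compute_compatible_dims_mapping helper introduced above; the input values are
# illustrative assumptions, not taken from the surrounding code:
#
#   >>> compute_compatible_dims_mapping([[0, -1], [-1, 1]])
#   [0, 1]    # -1 (replicated) yields to a sharded mesh dimension on each axis
#   >>> compute_compatible_dims_mapping([[0, -1], [1, -1]])
#   None      # axis 0 is sharded along two different mesh dimensions, so no merge exists
#   >>> compute_compatible_dims_mapping([[0, -1], [0]])
#   None      # a rank mismatch between the mappings is also treated as incompatible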
- - for arg_name in input_arg_names: - if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: - new_dims_mapping = [ - -1 for _ in range(input_dims_mapping_lens[arg_name]) - ] - for i in range(input_dims_mapping_lens[arg_name]): - new_idx = (max_dims_mapping_len - - input_dims_mapping_lens[arg_name]) + i - new_dims_mapping[i] = compatible_dims_mapping[new_idx] - if new_dims_mapping != input_dims_mapping_dict[arg_name]: - op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping) + def _update_op_node_dims_mapping(self, op_node, fwd=True): + changed = False + if (not op_node.is_op()) or (op_node.op() is None): + return False + # Skip reader op + op_desc = op_node.op() + if op_desc.type() == "create_py_reader" \ + or op_desc.type() == "create_double_buffer_reader" \ + or op_desc.type() == "read": + return False + dist_op = self._dist_context.get_dist_op_for_graph(op_node) + op_dist_attr = dist_op.dist_attr + if fwd: + for tensor_node in op_node.inputs: + if tensor_node.var() is not None: + if tensor_node.var().type() == core.VarDesc.VarType.READER: + continue + tensor_desc = tensor_node.var() + if op_dist_attr.is_annotated_input_dims_mapping( + tensor_desc.name()): + continue + tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node) + if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + tensor_dims_mapping = tensor_dist_attr.dims_mapping + op_dims_mapping = op_dist_attr.get_input_dims_mapping( + tensor_desc.name()) + compatible_dims_mapping = compute_compatible_dims_mapping( + [op_dims_mapping, tensor_dims_mapping]) + if (compatible_dims_mapping is not None) and \ + (compatible_dims_mapping != op_dims_mapping): + op_dist_attr.set_input_dims_mapping( + tensor_desc.name(), compatible_dims_mapping) + changed = True + # Find the most compatible implemenetations from the distributed operator + op_dist_impl = find_best_compatible_distributed_operator_impl( + dist_op, fwd=True) + # print("fwd op dist impl0", op_node.id(), op_desc.type(), op_dist_impl) + # print("fwd op dist impl0", op_node.id(), op_desc.type(), op_dist_impl.type, + # op_dist_impl.idx, "op dist_attr", op_dist_attr.impl_type, + # op_dist_attr.impl_idx) + assert op_dist_impl is not None, "Cannot find the dist op implementation." 
+ dim_changed = op_dist_impl.update_dims_mapping(dist_op) + if dim_changed: changed = True + if op_dist_impl.is_auto_compatible(dist_op): + # print("fwd op dist impl1", op_node.id(), op_desc.type(), op_dist_impl) + # print("fwd op dist impl1", op_node.id(), op_desc.type(), op_dist_impl.type, + # op_dist_impl.idx, "op dist_attr", op_dist_attr.impl_type, + # op_dist_attr.impl_idx) + if op_dist_impl.type == "elementwise": + op_dist_attr.impl_type = "default" + else: + op_dist_attr.impl_type = op_dist_impl.type + op_dist_attr.impl_idx = op_dist_impl.idx + # print("fwd op dist impl2", op_node.id(), op_desc.type(), op_dist_impl) + # print("fwd op dist impl2", op_node.id(), op_desc.type(), op_dist_impl.type, + # op_dist_impl.idx, "op dist_attr", op_dist_attr.impl_type, + # op_dist_attr.impl_idx) else: - if compatible_dims_mapping != input_dims_mapping_dict[arg_name]: - op_dist_attr.set_input_dims_mapping(arg_name, - compatible_dims_mapping) + for tensor_node in op_node.outputs: + if tensor_node.var() is not None: + if tensor_node.var().type() == core.VarDesc.VarType.READER: + continue + tensor_desc = tensor_node.var() + if op_dist_attr.is_annotated_output_dims_mapping( + tensor_desc.name()): + continue + tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node) + if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + tensor_dims_mapping = tensor_dist_attr.dims_mapping + op_dims_mapping = op_dist_attr.get_output_dims_mapping( + tensor_desc.name()) + compatible_dims_mapping = compute_compatible_dims_mapping( + [op_dims_mapping, tensor_dims_mapping]) + if (compatible_dims_mapping is not None) and \ + (compatible_dims_mapping != op_dims_mapping): + op_dist_attr.set_output_dims_mapping( + tensor_desc.name(), compatible_dims_mapping) + changed = True + # Find the most compatible implemenetations from the distributed operator + op_dist_impl = find_best_compatible_distributed_operator_impl( + dist_op, fwd=False) + # print("bwd op dist impl0", op_node.id(), op_desc.type(), op_dist_impl.type, + # op_dist_impl.idx, "op dist_attr", op_dist_attr.impl_type, + # op_dist_attr.impl_idx) + assert op_dist_impl is not None, "Cannot find the dist op implementation." 
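# Illustrative outcome of the impl selection in _update_op_node_dims_mapping (the
# concrete op and impl names below are assumptions): once update_dims_mapping has
# refined the mappings and is_auto_compatible(dist_op) holds, the chosen implementation
# is recorded on the op's dist_attr, e.g. for a matmul-like op
#
#   op_dist_attr.impl_type == "matmul_v2"
#   op_dist_attr.impl_idx  == 1
#
# with the special case that an "elementwise" implementation is recorded under the
# "default" impl_type, as handled explicitly in both the forward and backward branches.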
+ dim_changed = op_dist_impl.update_dims_mapping(dist_op) + if dim_changed: changed = True + if op_dist_impl.is_auto_compatible(dist_op): + # print("bwd op dist impl1", op_node.id(), op_desc.type(), op_dist_impl.type, + # op_dist_impl.idx, "op dist_attr", op_dist_attr.impl_type, + # op_dist_attr.impl_idx) + if op_dist_impl.type == "elementwise": + op_dist_attr.impl_type = "default" + else: + op_dist_attr.impl_type = op_dist_impl.type + op_dist_attr.impl_idx = op_dist_impl.idx + # print("bwd op dist impl2", op_node.id(), op_desc.type(), op_dist_impl.type, + # op_dist_impl.idx, "op dist_attr", op_dist_attr.impl_type, + # op_dist_attr.impl_idx) + return changed - for arg_name in output_arg_names: - dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - if compatible_dims_mapping != dims_mapping: - op_dist_attr.set_output_dims_mapping(arg_name, - compatible_dims_mapping) - changed = True - - return changed - - -def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): - changed = False - if (not tensor_node.is_var()) or (tensor_node.var() is None): - return False - tensor_desc = tensor_node.var() - # Skip reader tensor - if tensor_desc.type() == core.VarDesc.VarType.READER: - return False - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node) - assert tensor_dist_attr is not None - if tensor_dist_attr.is_annotated("dims_mapping"): - return False - tensor_dims_mapping = tensor_dist_attr.dims_mapping - if fwd: - dims_mapping_list = [] - for pred_op_node in tensor_node.inputs: - if pred_op_node.op() is not None: - if pred_op_node.op().type() == "create_py_reader" \ - or pred_op_node.op().type() == "create_double_buffer_reader" \ - or pred_op_node.op().type() == "read": - continue - op_dist_attr = dist_context.get_op_dist_attr_for_graph( - pred_op_node) - op_dims_mapping = op_dist_attr.get_output_dims_mapping( - tensor_desc.name()) - dims_mapping_list.append(op_dims_mapping) - dims_mapping_list.append(tensor_dims_mapping) - compatible_dims_mapping = compute_compatible_dims_mapping( - dims_mapping_list) - if (compatible_dims_mapping is not None) and \ - (compatible_dims_mapping != tensor_dims_mapping): - tensor_dist_attr.dims_mapping = compatible_dims_mapping - changed = True - else: - dims_mapping_list = [] - for succ_op_node in tensor_node.outputs: - if succ_op_node.op() is not None: - if succ_op_node.op().type() == "create_py_reader" \ - or succ_op_node.op().type() == "create_double_buffer_reader" \ - or succ_op_node.op().type() == "read": - continue - op_dist_attr = dist_context.get_op_dist_attr_for_graph( - succ_op_node) - op_dims_mapping = op_dist_attr.get_input_dims_mapping( - tensor_desc.name()) - dims_mapping_list.append(op_dims_mapping) - dims_mapping_list.append(tensor_dims_mapping) - compatible_dims_mapping = compute_compatible_dims_mapping( - dims_mapping_list) - if (compatible_dims_mapping is not None) and \ - (compatible_dims_mapping != tensor_dims_mapping): - tensor_dist_attr.dims_mapping = compatible_dims_mapping - changed = True - return changed - - -def update_op_node_dims_mapping(dist_context, op_node, fwd=True): - changed = False - if (not op_node.is_op()) or (op_node.op() is None): - return False - # Skip reader op - op_desc = op_node.op() - if op_desc.type() == "create_py_reader" \ - or op_desc.type() == "create_double_buffer_reader" \ - or op_desc.type() == "read": - return False - dist_op = dist_context.get_dist_op_for_graph(op_node) - op_dist_attr = dist_op.dist_attr - if fwd: - for tensor_node in op_node.inputs: - if 
tensor_node.var() is not None: - if tensor_node.var().type() == core.VarDesc.VarType.READER: - continue - tensor_desc = tensor_node.var() - if op_dist_attr.is_annotated_input_dims_mapping( - tensor_desc.name()): - continue - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - tensor_dims_mapping = tensor_dist_attr.dims_mapping - op_dims_mapping = op_dist_attr.get_input_dims_mapping( - tensor_desc.name()) - compatible_dims_mapping = compute_compatible_dims_mapping( - [op_dims_mapping, tensor_dims_mapping]) - if (compatible_dims_mapping is not None) and \ - (compatible_dims_mapping != op_dims_mapping): - op_dist_attr.set_input_dims_mapping(tensor_desc.name(), - compatible_dims_mapping) - changed = True - # Find the most compatible implemenetations from the distributed operator - op_dist_impl = find_best_compatible_distributed_operator_impl( - dist_op, fwd=True) - assert op_dist_impl is not None, "Cannot find the dist op implementation." - dim_changed = op_dist_impl.update_dims_mapping(dist_op) - if dim_changed: - changed = True - if op_dist_impl.is_auto_compatible(dist_op): - if op_dist_impl.type == "elementwise": - op_dist_attr.impl_type = "default" - else: - op_dist_attr.impl_type = op_dist_impl.type - op_dist_attr.impl_idx = op_dist_impl.idx - else: - for tensor_node in op_node.outputs: - if tensor_node.var() is not None: - if tensor_node.var().type() == core.VarDesc.VarType.READER: - continue - tensor_desc = tensor_node.var() - if op_dist_attr.is_annotated_output_dims_mapping( - tensor_desc.name()): - continue - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - tensor_dims_mapping = tensor_dist_attr.dims_mapping - op_dims_mapping = op_dist_attr.get_output_dims_mapping( - tensor_desc.name()) - compatible_dims_mapping = compute_compatible_dims_mapping( - [op_dims_mapping, tensor_dims_mapping]) - if (compatible_dims_mapping is not None) and \ - (compatible_dims_mapping != op_dims_mapping): - op_dist_attr.set_output_dims_mapping( - tensor_desc.name(), compatible_dims_mapping) - changed = True - # Find the most compatible implemenetations from the distributed operator - op_dist_impl = find_best_compatible_distributed_operator_impl( - dist_op, fwd=False) - assert op_dist_impl is not None, "Cannot find the dist op implementation." 
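# A small sketch of the mesh merge used by _update_process_mesh below (the meshes are
# illustrative assumptions): compute_compatible_process_mesh, defined earlier in this
# file, merges annotations pairwise, so None yields to any mesh, equal meshes are kept,
# and a mesh whose process set is contained in another yields to the larger one, e.g.
# merging ProcessMesh([0, 1]) with ProcessMesh([[0, 1], [2, 3]]) keeps the 2-D mesh over
# processes 0-3; if neither process set contains the other, no compatible mesh exists
# and None is returned.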
- dim_changed = op_dist_impl.update_dims_mapping(dist_op) - if dim_changed: - changed = True - if op_dist_impl.is_auto_compatible(dist_op): - if op_dist_impl.type == "elementwise": - op_dist_attr.impl_type = "default" + def _update_process_mesh(self): + def _find_nearset_node(nodes, idx): + for node in reversed(nodes[:idx]): + node_dist_attr = self._dist_context.get_dist_attr_for_graph( + node) + if node_dist_attr.process_mesh is not None: + return node + + total_reach_fix_point = False + while not total_reach_fix_point: + total_changed = False + for is_fwd in [True, False]: + all_nodes = self._dist_context.serial_ordered_nodes \ + if is_fwd else reversed(self._dist_context.serial_ordered_nodes) + reach_fix_point = False + while not reach_fix_point: + changed = False + for idx, node in enumerate(all_nodes): + nearest_node = _find_nearset_node( + self._dist_context.serial_ordered_nodes, idx) + if nearest_node is None: + continue + nearest_node_dis_attr = self._dist_context.get_dist_attr_for_graph( + nearest_node) + nearest_process_mesh = nearest_node_dis_attr.process_mesh + cur_node_dist_attr = self._dist_context.get_dist_attr_for_graph( + node) + cur_process_mesh = cur_node_dist_attr.process_mesh + compatible_process_mesh = compute_compatible_process_mesh( + [cur_process_mesh, nearest_process_mesh]) + if compatible_process_mesh is not None \ + and cur_process_mesh != compatible_process_mesh: + cur_node_dist_attr.process_mesh = compatible_process_mesh + changed = True + if changed: + reach_fix_point = False + total_changed = True + else: + reach_fix_point = True + if total_changed: + total_reach_fix_point = False else: - op_dist_attr.impl_type = op_dist_impl.type - op_dist_attr.impl_idx = op_dist_impl.idx - return changed - - -def complete_annotation(program, dist_context=None): - """ Complete annotation for the partial annotated program. - - Arguments: - program: partial annotated program. - dist_context: the distributed context is used to store distributed attributes for program. - If not provided, the default one will be used. - Returns: - program: completed annotated program. 
- """ - - # Use the default distribted context for completeion if there is no one - if dist_context is None: - dist_context = get_default_distributed_context() - dist_context.serial_program = program - else: - dist_context.serial_program = program - - # print_program_with_dist_attr(program, dist_context) - - # Initialize distributed attributes for all var and op node in program - dist_context.init_dist_attr_for_program() - - # Initialize distributed attributes for all var and op node in graph - dist_context.init_dist_attr_for_graph() - - # Complete process mesh for each node - all_nodes = list(dist_context.serial_graph.all_nodes()) + total_reach_fix_point = True - def sort_key_fun(node): - first = -1 - if node.is_op(): - first = 0 - else: - first = 1 - second = -1 - if node.is_op() and node.op() is not None: - second = node.op().id() - if node.is_var() and node.var() is not None: - second = node.var().id() - return (first, second) - - all_nodes.sort(key=sort_key_fun) - - reach_fix_point = False - while not reach_fix_point: - total_changed = False - reach_fwd_fix_point = False - reach_bwd_fix_point = False - while not reach_fwd_fix_point: + def _update_dims_mapping(self): + # Complete dims_mapping for each node + reach_fix_point = False + while not reach_fix_point: changed = False - for node in all_nodes: - if node.is_var() and node.var() is not None: - tensor_changed = update_tensor_node_process_mesh( - dist_context, node, fwd=True) - if tensor_changed: - changed = True - if node.is_op() and node.op() is not None: - op_changed = update_op_node_process_mesh( - dist_context, node, fwd=True) - if op_changed: - changed = True + for is_fwd in [True, False]: + all_nodes = self._dist_context.serial_ordered_nodes \ + if is_fwd else reversed(self._dist_context.serial_ordered_nodes) + for node in all_nodes: + if node.is_var() and node.var() is not None: + tensor_changed = self._update_tensor_node_dims_mapping( + node, fwd=is_fwd) + if tensor_changed: + changed = True + if node.is_op() and node.op() is not None: + op_changed = self._update_op_node_dims_mapping( + node, fwd=is_fwd) + if op_changed: + changed = True if changed: - reach_fwd_fix_point = False - total_changed = True + reach_fix_point = False else: - reach_fwd_fix_point = True - while not reach_bwd_fix_point: - changed = False - for node in all_nodes: - if node.is_var() and node.var() is not None: - tensor_changed = update_tensor_node_process_mesh( - dist_context, node, fwd=False) - if tensor_changed: - changed = True - if node.is_op() and node.op() is not None: - op_changed = update_op_node_process_mesh( - dist_context, node, fwd=False) - if op_changed: - changed = True - if changed: - reach_bwd_fix_point = False - total_changed = True - else: - reach_bwd_fix_point = True - if total_changed: - reach_fix_point = False - else: - reach_fix_point = True - # Validation the completion of process meshes and should be moved to a proper location - is_wrong = False - for node in all_nodes: - if node.is_var() and node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - node) - if tensor_dist_attr.process_mesh is None: - msg_str = "" - for op_node in node.inputs: - if op_node.op() is not None: - op_dist_attr = dist_context.get_op_dist_attr_for_graph( - op_node) - msg_str += "{} [{}], ".format( - op_node.op().type(), - op_dist_attr.process_mesh) - else: - msg_str += "{} [{}], ".format(op_node.name(), - None) - for op_node in node.outputs: - if op_node.op() is not None: - op_dist_attr = 
dist_context.get_op_dist_attr_for_graph( - op_node) - msg_str += "{} [{}], ".format( - op_node.op().type(), - op_dist_attr.process_mesh) - else: - msg_str += "{} [{}], ".format(op_node.name(), - None) - msg_str = "Cannot decide ProcessMesh of {} among {}. Please use shard_tensor api explicitly to annotate it".format( - node.var().name(), msg_str[:-2]) - is_wrong = True - print(msg_str) - if node.is_op() and node.op() is not None: - op_dist_attr = dist_context.get_op_dist_attr_for_graph(node) - if op_dist_attr.process_mesh is None: - msg_str = "" - for tensor_node in node.inputs: - if tensor_node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - msg_str += "{} [{}], ".format( - tensor_node.var().name(), - tensor_dist_attr.process_mesh) - else: - msg_str += "{} [{}], ".format( - tensor_node.name(), None) - for tensor_node in node.outputs: - if tensor_node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - msg_str += "{} [{}], ".format( - tensor_node.var().name(), - tensor_dist_attr.process_mesh) - else: - msg_str += "{} [{}], ".format( - tensor_node.name(), None) - msg_str = "Cannot decide ProcessMesh of {} among {}. Please use shard_op api explicitly to annotate it".format( - node.op().type(), msg_str[:-2]) - is_wrong = True - print(msg_str) - if node.is_op() and node.op() is None: - print("op op is None", node.name()) - if is_wrong: - assert False, "Cannot complete process_meshes of the program." - - # Complete dims_mapping for each node - reach_fix_point = False - while not reach_fix_point: - changed = False - for node in all_nodes: - if node.is_var() and node.var() is not None: - tensor_changed = update_tensor_node_dims_mapping( - dist_context, node, fwd=True) - if tensor_changed: - changed = True - if node.is_op() and node.op() is not None: - op_changed = update_op_node_dims_mapping( - dist_context, node, fwd=True) - if op_changed: - changed = True - for node in reversed(all_nodes): - if node.is_var() and node.var() is not None: - tensor_changed = update_tensor_node_dims_mapping( - dist_context, node, fwd=False) - if tensor_changed: - changed = True - if node.is_op() and node.op() is not None: - op_changed = update_op_node_dims_mapping( - dist_context, node, fwd=False) - if op_changed: - changed = True - if changed: - reach_fix_point = False - else: - reach_fix_point = True - - # Copy the corresponding distributed attribute from graph to program - dist_context.copy_dist_attr_from_graph_to_program() - dist_context.clear_dist_info_for_graph() - - # Do the validation check and amend some completion - dist_context.amend_dist_attr_for_program() - - # print_program_with_dist_attr(program, dist_context) - dist_context.validate_dist_attr_for_program() + reach_fix_point = True + + def complete_forward_annotation(self, serial_main_program): + """ Complete annotation for the partial annotated serial_main_program. + + Arguments: + serial_main_program: partial annotated serial_main_program. + + Returns: + serial_main_program: completed annotated serial_main_program. 
+ """ + + # Use the default distribted context for completeion if there is no one + self._dist_context.serial_program = serial_main_program + + # Initialize distributed attributes for all var and op node in serial_main_program + self._dist_context.init_dist_attr_for_program() + + # Initialize distributed attributes for all var and op node in graph + self._dist_context.init_dist_attr_for_graph() + + self._update_process_mesh() + + # Complete dims_mapping for each node + self._update_dims_mapping() + + # Copy the corresponding distributed attribute from graph to serial_main_program + self._dist_context.copy_dist_attr_from_graph_to_program() + self._dist_context.clear_dist_info_for_graph() + + # print_serial_main_program_with_dist_attr(serial_main_program, self._dist_context) + # Do the validation check and amend some completion + self._dist_context.amend_dist_attr_for_program() + + # print_serial_main_program_with_dist_attr(serial_main_program, self._dist_context) + self._dist_context.validate_dist_attr_for_program() + + return serial_main_program + + def complete_backward_annotation(self, serial_main_program): + """Complete the annotation of vars and ops in the backward phase for parallel program.""" + + def _is_grad_var_name(name): + if "@GRAD" in name: + return True + return False + + def _get_forward_varname_from_grad_varname(grad_var_name): + assert _is_grad_var_name( + grad_var_name), "[{}] is not a grad varnme.".format( + grad_var_name) + return grad_var_name[:grad_var_name.find("@GRAD")] + + def _get_op_by_id(ops, id): + for op in ops: + if op.desc.id() == id: + return op + return None + + first_backward_op_idx = -1 + for idx, op in enumerate(serial_main_program.global_block().ops): + if int(op.attr('op_role')) == int( + int(core.op_proto_and_checker_maker.OpRole.Backward) | int( + core.op_proto_and_checker_maker.OpRole.Loss)): + assert op.type == "fill_constant" + first_backward_op_idx = idx + break + + assert first_backward_op_idx >= 0, "No backward procedure found in this program." 
+ + ops = list(serial_main_program.global_block().ops) + vars = serial_main_program.global_block().vars + dist_op_context = self._dist_context.dist_op_context + + for idx in range(first_backward_op_idx, len(ops)): + + # complete the initial grad loss op + if idx == first_backward_op_idx: + assert ops[idx].type == "fill_constant" + assert len( + ops[idx].input_arg_names + ) == 0, "first backward op should has only ONE output, but got [{}]".format( + len(ops[idx].input_arg_names)) + assert len( + ops[idx].output_arg_names + ) == 1, "first backward op should has only ONE output, but got [{}]".format( + len(ops[idx].output_arg_names)) + + grad_var = vars[ops[idx].output_arg_names[0]] + forward_var_name = _get_forward_varname_from_grad_varname( + grad_var.name) + forward_var = vars[forward_var_name] + + # TODO complete other attribte for grad var + tensor_dist_attr = TensorDistributedAttribute() + process_mesh = self._dist_context.get_tensor_dist_attr_for_program( + forward_var).process_mesh + dims_mapping = self._dist_context.get_tensor_dist_attr_for_program( + forward_var).dims_mapping + tensor_dist_attr.dims_mapping = dims_mapping + tensor_dist_attr.process_mesh = process_mesh + self._dist_context.set_tensor_dist_attr_for_program( + grad_var, tensor_dist_attr) - return program - - -def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): - """Complete the annotation of vars and ops in the backward phase for parallel program.""" - - def _is_grad_var_name(name): - if "@GRAD" in name: - return True - return False - - def _get_forward_varname_from_grad_varname(grad_var_name): - assert _is_grad_var_name( - grad_var_name), "[{}] is not a grad varnme.".format(grad_var_name) - return grad_var_name[:grad_var_name.find("@GRAD")] - - def _get_op_by_id(ops, id): - for op in ops: - if op.desc.id() == id: - return op - return None + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.process_mesh = process_mesh + op_dist_attr.set_output_dims_mapping(grad_var.name, + dims_mapping) + self._dist_context.set_op_dist_attr_for_program(ops[idx], + op_dist_attr) + continue - if dist_context is None: - dist_context = get_default_distributed_context() - - first_backward_op_idx = -1 - for idx, op in enumerate(auto_parallel_main_prog.global_block().ops): - if int(op.attr('op_role')) == int( - int(core.op_proto_and_checker_maker.OpRole.Backward) | int( - core.op_proto_and_checker_maker.OpRole.Loss)): - assert op.type == "fill_constant" - first_backward_op_idx = idx - break - - assert first_backward_op_idx >= 0, "No backward procedure found in this program." 
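# In the grad-op loop of complete_backward_annotation: an xxx_grad op that appears in
# dist_op_context.grad_op_id_to_op_id reuses the process_mesh of its forward op, each
# X@GRAD input/output reuses the dims_mapping recorded for X on that forward op, and a
# gradient-accumulation "sum" op (which has no forward counterpart) reuses the dist_attr
# of the forward var whose gradient it merges.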
- - ops = list(auto_parallel_main_prog.global_block().ops) - vars = auto_parallel_main_prog.global_block().vars - dist_op_context = dist_context.dist_op_context - - for idx in range(first_backward_op_idx, len(ops)): - - # complete the initial grad loss op - if idx == first_backward_op_idx: - assert ops[idx].type == "fill_constant" - assert len( - ops[idx].input_arg_names - ) == 0, "first backward op should has only ONE output, but got [{}]".format( - len(ops[idx].input_arg_names)) - assert len( - ops[idx].output_arg_names - ) == 1, "first backward op should has only ONE output, but got [{}]".format( - len(ops[idx].output_arg_names)) - - grad_var = vars[ops[idx].output_arg_names[0]] - forward_var_name = _get_forward_varname_from_grad_varname( - grad_var.name) - forward_var = vars[forward_var_name] - - # TODO complete other attribte for grad var - tensor_dist_attr = TensorDistributedAttribute() - process_mesh = dist_context.get_tensor_dist_attr_for_program( - forward_var).process_mesh - dims_mapping = dist_context.get_tensor_dist_attr_for_program( - forward_var).dims_mapping - tensor_dist_attr.dims_mapping = dims_mapping - tensor_dist_attr.process_mesh = process_mesh - dist_context.set_tensor_dist_attr_for_program(grad_var, - tensor_dist_attr) - - op_dist_attr = OperatorDistributedAttribute() - op_dist_attr.process_mesh = process_mesh - op_dist_attr.set_output_dims_mapping(grad_var.name, dims_mapping) - dist_context.set_op_dist_attr_for_program(ops[idx], op_dist_attr) - continue - - # complete the annotation of grad op (xxx_grad op or sum op) - # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id - grad_op = ops[idx] - if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id: - # TODO support the case where one forward op corresponding to multiple xxx_grad op - forward_op = _get_op_by_id( - ops[:first_backward_op_idx], - dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()]) - assert forward_op is not None - - # op dist attr - forward_op_dist_attr = dist_context.get_op_dist_attr_for_program( - forward_op) - forward_op_process_mesh = forward_op_dist_attr.process_mesh - grad_op_dist_attr = OperatorDistributedAttribute() - grad_op_dist_attr.process_mesh = forward_op_process_mesh - - # var - for input_name in grad_op.input_arg_names: - input_var = vars[input_name] - ref_dims_mapping = None - if "@GRAD" in input_name: - forward_name = _get_forward_varname_from_grad_varname( - input_name) - ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping( - forward_name) - else: - if forward_op_dist_attr.get_input_dims_mapping(input_name): - ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( + # complete the annotation of grad op (xxx_grad op or sum op) + # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id + grad_op = ops[idx] + if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id: + # TODO support the case where one forward op corresponding to multiple xxx_grad op + forward_op = _get_op_by_id( + ops[:first_backward_op_idx], + dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()]) + assert forward_op is not None + + # op dist attr + forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( + forward_op) + forward_op_process_mesh = forward_op_dist_attr.process_mesh + grad_op_dist_attr = OperatorDistributedAttribute() + grad_op_dist_attr.process_mesh = forward_op_process_mesh + + # var + for input_name in grad_op.input_arg_names: + input_var = vars[input_name] + ref_dims_mapping = None + if "@GRAD" in input_name: + 
forward_name = _get_forward_varname_from_grad_varname( input_name) - else: ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping( - input_name) - - assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( - input_var.name) - grad_op_dist_attr.set_input_dims_mapping(input_name, - ref_dims_mapping) - - for output_name in grad_op.desc.output_names(): - assert len(grad_op.desc.output(output_name)) in [0, 1] - if _is_grad_var_name(output_name): - input_name = _get_forward_varname_from_grad_varname( - output_name) - else: - assert grad_op.type in [ - "cast", "c_identity", "c_allreduce_sum" - ] - input_name = "X" - assert input_name in forward_op.desc.input_names( - ), "var [{}] in op [{}]'s output but could not find [{}] in its forward op".format( - output_name, grad_op.type, input_name) - if len(grad_op.desc.output(output_name)) == 1: - # tensor dist attr - output_var = vars[grad_op.desc.output(output_name)[0]] - forward_name = _get_forward_varname_from_grad_varname( - output_var.name) - ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( - forward_name) - - output_var_dist_attr = TensorDistributedAttribute() - output_var_dist_attr.dims_mapping = ref_dims_mapping - output_var_dist_attr.process_mesh = forward_op_process_mesh - dist_context.set_tensor_dist_attr_for_program( - output_var, output_var_dist_attr) - - grad_op_dist_attr.set_output_dims_mapping(output_var.name, - ref_dims_mapping) - - dist_context.set_op_dist_attr_for_program(grad_op, - grad_op_dist_attr) - - # only sum op for merge mutiple version grad has no a corresponding mapping in grad_op_id_to_op_id - else: - assert grad_op.type == "sum", "got unexpect op [{}]".format( - str(grad_op.type)) - assert all(map(_is_grad_var_name, grad_op.input_arg_names)) - assert len(grad_op.output_arg_names) == 1 - - ref_forward_var_name = _get_forward_varname_from_grad_varname( - grad_op.output_arg_names[0]) - forward_var = vars[ref_forward_var_name] - ref_forward_var_dims_mapping = dist_context.get_tensor_dist_attr_for_program( - forward_var).dims_mapping - ref_forward_var_process_mesh = dist_context.get_tensor_dist_attr_for_program( - forward_var).process_mesh - - # output - tensor_dist_attr = TensorDistributedAttribute() - tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping - tensor_dist_attr.process_mesh = ref_forward_var_process_mesh - dist_context.set_tensor_dist_attr_for_program( - vars[grad_op.output_arg_names[0]], tensor_dist_attr) - - # op - grad_op_dist_attr = OperatorDistributedAttribute() - grad_op_dist_attr.process_mesh = ref_forward_var_process_mesh - for var_name in grad_op.input_arg_names: - assert _get_forward_varname_from_grad_varname( - var_name) == ref_forward_var_name - grad_op_dist_attr.set_input_dims_mapping( - var_name, ref_forward_var_dims_mapping) - - grad_op_dist_attr.set_output_dims_mapping( - grad_op.output_arg_names[0], ref_forward_var_dims_mapping) - dist_context.set_op_dist_attr_for_program(grad_op, - grad_op_dist_attr) - - -def complete_update_annotation(auto_parallel_main_prog, dist_context): - """Complete the annotation of vars and ops in the update phase for parallel program.""" - - if dist_context is None: - dist_context = get_default_distributed_context() - - ops = list(auto_parallel_main_prog.global_block().ops) - vars = auto_parallel_main_prog.global_block().vars - learning_rate_completed = False - - for idx in range(len(ops)): - - # complete the annotation of the optimizer op. 
- # TODO to add attribute for moment var - op = ops[idx] - if int(op.attr('op_role')) == int(OpRole.Optimize): - if op.type == "clip_by_norm": - - param_grad = vars[op.input("X")[0]] - param_grad_dist_attr = dist_context.get_tensor_dist_attr_for_program( - param_grad) - assert param_grad_dist_attr is not None - ref_process_mesh = param_grad_dist_attr.process_mesh - ref_dims_mapping = param_grad_dist_attr.dims_mapping - - out = vars[op.output("Out")[0]] - out_dist_attr = TensorDistributedAttribute() - out_dist_attr.process_mesh = ref_process_mesh - out_dist_attr.dims_mapping = ref_dims_mapping - dist_context.set_tensor_dist_attr_for_program(out, - out_dist_attr) + forward_name) + else: + if forward_op_dist_attr.get_input_dims_mapping( + input_name): + ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( + input_name) + else: + ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping( + input_name) + + assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( + input_var.name) + grad_op_dist_attr.set_input_dims_mapping(input_name, + ref_dims_mapping) - op_dist_attr = OperatorDistributedAttribute() - op_dist_attr.process_mesh = ref_process_mesh - op_dist_attr.set_input_dist_attr(param_grad.name, - param_grad_dist_attr) - op_dist_attr.set_output_dist_attr(out.name, out_dist_attr) - dist_context.set_op_dist_attr_for_program(op, op_dist_attr) - - if "Grad" in op.input_names and "Param" in ops[idx].input_names: - assert len(op.input( - "Param")) == 1, "Only support one-to-one now." - assert len(op.input( - "Grad")) == 1, "Only support one-to-one now." - param = vars[op.input("Param")[0]] - grad_var = vars[op.input("Grad")[0]] - - param_dist_attr = dist_context.get_tensor_dist_attr_for_program( - param) - assert param_dist_attr is not None - ref_process_mesh = dist_context.get_tensor_dist_attr_for_program( - param).process_mesh - assert ref_process_mesh is not None - ref_dims_mapping = dist_context.get_tensor_dist_attr_for_program( - param).dims_mapping - assert ref_dims_mapping is not None - op_dist_attr = OperatorDistributedAttribute() - op_dist_attr.process_mesh = ref_process_mesh - op_dist_attr.set_input_dims_mapping(grad_var.name, - ref_dims_mapping) - op_dist_attr.set_input_dims_mapping(param.name, - ref_dims_mapping) - op_dist_attr.set_output_dims_mapping(param.name, - ref_dims_mapping) - learning_var = vars[op.input("LearningRate")[0]] - op_dist_attr.set_input_dims_mapping(learning_var.name, [-1]) - op_dist_attr.set_output_dims_mapping(learning_var.name, [-1]) - - if not learning_rate_completed: - learning_rate_completed = True - var_dist_attr = TensorDistributedAttribute() - var_dist_attr.process_mesh = ref_process_mesh - var_dist_attr.dims_mapping = [-1] - dist_context.set_tensor_dist_attr_for_program(learning_var, - var_dist_attr) - - for input_name in op.desc.input_names(): - - if input_name in [ - 'Param', 'Grad', 'LearningRate', "SkipUpdate", - "Beta1Tensor", "Beta2Tensor", "EpsilonTensor", - "MasterParam" - ]: - continue + for output_name in grad_op.desc.output_names(): + assert len(grad_op.desc.output(output_name)) in [0, 1] + if _is_grad_var_name(output_name): + input_name = _get_forward_varname_from_grad_varname( + output_name) + else: + assert grad_op.type in [ + "cast", "c_identity", "c_allreduce_sum" + ] + input_name = "X" + assert input_name in forward_op.desc.input_names( + ), "var [{}] in op [{}]'s output but could not find [{}] in its forward op".format( + output_name, grad_op.type, input_name) + if len(grad_op.desc.output(output_name)) 
== 1: + # tensor dist attr + output_var = vars[grad_op.desc.output(output_name)[0]] + forward_name = _get_forward_varname_from_grad_varname( + output_var.name) + ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( + forward_name) - assert len(op.desc.input(input_name)) == 1 - input_var = vars[op.desc.input(input_name)[0]] - input_var_attr = TensorDistributedAttribute() + output_var_dist_attr = TensorDistributedAttribute() + output_var_dist_attr.dims_mapping = ref_dims_mapping + output_var_dist_attr.process_mesh = forward_op_process_mesh + self._dist_context.set_tensor_dist_attr_for_program( + output_var, output_var_dist_attr) - if "Beta1Pow" in input_name or "Beta2Pow" in input_name: - input_var_attr.dims_mapping = [-1] - op_dist_attr.set_input_dims_mapping(input_var.name, - [-1]) - op_dist_attr.set_output_dims_mapping(input_var.name, - [-1]) - else: - assert "Moment" in input_name - input_var_attr.dims_mapping = ref_dims_mapping - op_dist_attr.set_input_dims_mapping(input_var.name, - ref_dims_mapping) - op_dist_attr.set_output_dims_mapping(input_var.name, - ref_dims_mapping) + grad_op_dist_attr.set_output_dims_mapping( + output_var.name, ref_dims_mapping) - input_var_attr.process_mesh = ref_process_mesh - dist_context.set_tensor_dist_attr_for_program( - input_var, input_var_attr) + self._dist_context.set_op_dist_attr_for_program( + grad_op, grad_op_dist_attr) - dist_context.set_op_dist_attr_for_program(op, op_dist_attr) - continue + # only sum op for merge mutiple version grad has no a corresponding mapping in grad_op_id_to_op_id + else: + assert grad_op.type == "sum", "got unexpect op [{}]".format( + str(grad_op.type)) + assert all(map(_is_grad_var_name, grad_op.input_arg_names)) + assert len(grad_op.output_arg_names) == 1 + + ref_forward_var_name = _get_forward_varname_from_grad_varname( + grad_op.output_arg_names[0]) + forward_var = vars[ref_forward_var_name] + ref_forward_var_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program( + forward_var).dims_mapping + ref_forward_var_process_mesh = self._dist_context.get_tensor_dist_attr_for_program( + forward_var).process_mesh + + # output + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping + tensor_dist_attr.process_mesh = ref_forward_var_process_mesh + self._dist_context.set_tensor_dist_attr_for_program( + vars[grad_op.output_arg_names[0]], tensor_dist_attr) + + # op + grad_op_dist_attr = OperatorDistributedAttribute() + grad_op_dist_attr.process_mesh = ref_forward_var_process_mesh + for var_name in grad_op.input_arg_names: + assert _get_forward_varname_from_grad_varname( + var_name) == ref_forward_var_name + grad_op_dist_attr.set_input_dims_mapping( + var_name, ref_forward_var_dims_mapping) + + grad_op_dist_attr.set_output_dims_mapping( + grad_op.output_arg_names[0], ref_forward_var_dims_mapping) + self._dist_context.set_op_dist_attr_for_program( + grad_op, grad_op_dist_attr) + + def complete_update_annotation(self, serial_main_program): + """Complete the annotation of vars and ops in the update phase for parallel program.""" + ops = list(serial_main_program.global_block().ops) + vars = serial_main_program.global_block().vars + learning_rate_completed = False + + for idx in range(len(ops)): + + # complete the annotation of the optimizer op. 
+ # TODO to add attribute for moment var + op = ops[idx] + if int(op.attr('op_role')) == int(OpRole.Optimize): + + if "Grad" in op.input_names and "Param" in ops[idx].input_names: + assert len(op.input( + "Param")) == 1, "Only support one-to-one now." + assert len(op.input( + "Grad")) == 1, "Only support one-to-one now." + param = vars[op.input("Param")[0]] + grad_var = vars[op.input("Grad")[0]] + + param_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( + param) + assert param_dist_attr is not None + ref_process_mesh = self._dist_context.get_tensor_dist_attr_for_program( + param).process_mesh + assert ref_process_mesh is not None + ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program( + param).dims_mapping + assert ref_dims_mapping is not None + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.process_mesh = ref_process_mesh + op_dist_attr.set_input_dims_mapping(grad_var.name, + ref_dims_mapping) + op_dist_attr.set_input_dims_mapping(param.name, + ref_dims_mapping) + op_dist_attr.set_output_dims_mapping(param.name, + ref_dims_mapping) + learning_var = vars[op.input("LearningRate")[0]] + op_dist_attr.set_input_dims_mapping(learning_var.name, [-1]) + op_dist_attr.set_output_dims_mapping(learning_var.name, + [-1]) + + if not learning_rate_completed: + learning_rate_completed = True + var_dist_attr = TensorDistributedAttribute() + var_dist_attr.process_mesh = ref_process_mesh + var_dist_attr.dims_mapping = [-1] + self._dist_context.set_tensor_dist_attr_for_program( + learning_var, var_dist_attr) + + for input_name in op.desc.input_names(): + + if input_name in [ + 'Param', 'Grad', 'LearningRate', "SkipUpdate", + "Beta1Tensor", "Beta2Tensor", "EpsilonTensor", + "MasterParam" + ]: + continue + + assert len(op.desc.input(input_name)) == 1 + input_var = vars[op.desc.input(input_name)[0]] + input_var_attr = TensorDistributedAttribute() + + if "Beta1Pow" in input_name or "Beta2Pow" in input_name: + input_var_attr.dims_mapping = [-1] + op_dist_attr.set_input_dims_mapping(input_var.name, + [-1]) + op_dist_attr.set_output_dims_mapping(input_var.name, + [-1]) + else: + assert "Moment" in input_name + input_var_attr.dims_mapping = ref_dims_mapping + op_dist_attr.set_input_dims_mapping( + input_var.name, ref_dims_mapping) + op_dist_attr.set_output_dims_mapping( + input_var.name, ref_dims_mapping) + + input_var_attr.process_mesh = ref_process_mesh + self._dist_context.set_tensor_dist_attr_for_program( + input_var, input_var_attr) + + self._dist_context.set_op_dist_attr_for_program( + op, op_dist_attr) + continue diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index ad3a53ff17d76..e06811df88179 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -247,23 +247,23 @@ def get_op_dist_attr_for_graph(self, serial_op_node): # new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr) # self._dist_ops_for_graph[serial_op_node_id] = new_dist_op - # def get_dist_attr_for_graph(self, serial_node): - # if serial_node.is_var() and serial_node.var() is not None: - # serial_tensor_node_id = serial_node.id() - # dist_tensor = self._dist_tensors_for_graph.get( - # serial_tensor_node_id, None) - # if dist_tensor: - # return dist_tensor.dist_attr - # else: - # return None - # if serial_node.is_op() and serial_node.op() is not None: - # serial_op_node_id = serial_node.id() - # dist_op = 
self._dist_ops_for_graph.get(serial_op_node_id, None) - # if dist_op: - # return dist_op.dist_attr - # else: - # return None - # return None + def get_dist_attr_for_graph(self, serial_node): + if serial_node.is_var() and serial_node.var() is not None: + serial_tensor_node_id = serial_node.id() + dist_tensor = self._dist_tensors_for_graph.get( + serial_tensor_node_id, None) + if dist_tensor: + return dist_tensor.dist_attr + else: + return None + if serial_node.is_op() and serial_node.op() is not None: + serial_op_node_id = serial_node.id() + dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None) + if dist_op: + return dist_op.dist_attr + else: + return None + return None def init_dist_attr_for_program(self): assert self._serial_program, \ diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index d6035d02953ac..43f5fa264790f 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -32,7 +32,7 @@ from .dist_context import DistributedContext from .dist_context import get_default_distributed_context from .dist_context import set_default_distributed_context -from .completion import complete_annotation, complete_backward_annotation, complete_update_annotation +from .completion import Completer from .partitioner import Partitioner from .process_group import get_all_process_groups from .process_group import get_process_group @@ -130,8 +130,8 @@ def _generate_backward(self, main_program, startup_program, loss, no_grad_set, callbacks, distop_context=self._dist_context.dist_op_context) - complete_backward_annotation( - main_program, dist_context=self._dist_context) + self._completer = Completer(self._dist_context) + self._completer.complete_backward_annotation(main_program) return params_grads @@ -142,8 +142,8 @@ def _apply_optimize(self, main_program, startup_program, params_grads): params_grads) # update completion - complete_update_annotation( - main_program, dist_context=self._dist_context) + self._completer = Completer(self._dist_context) + self._completer.complete_update_annotation(main_program) return optimize_ops @@ -179,8 +179,9 @@ def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False): # Annotation completion self._dist_context = DistributedContext() _logger.info("Start annotation dist attr.") - completed_main_program = complete_annotation(serial_main_program, - self._dist_context) + self._completer = Completer(self._dist_context) + completed_main_program = self._completer.complete_forward_annotation( + serial_main_program) else: completed_main_program = serial_main_program self._dist_context = copy.deepcopy(dist_context) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py index 05d71aca5db2c..bc4f1671f4e20 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py @@ -27,6 +27,7 @@ from paddle.fluid import layers from paddle.nn.layer.transformer import _convert_param_attr_to_list import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import 
append_distributed_attr_suffix @@ -154,10 +155,9 @@ def test_mlp_dp(self): dist_context = DistributedContext() train_program, start_program = mlp_pretrain_forward(train_program, start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) - # print_program_with_dist_attr(complete_train_program, - # dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_mlp_mp(self): @@ -171,10 +171,9 @@ def test_mlp_mp(self): dist_context = DistributedContext() train_program, start_program = mlp_pretrain_forward(train_program, start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) - # print_program_with_dist_attr(complete_train_program, - # dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_mlp_dp_mp(self): @@ -189,10 +188,9 @@ def test_mlp_dp_mp(self): dist_context = DistributedContext() train_program, start_program = mlp_pretrain_forward(train_program, start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) - # print_program_with_dist_attr(complete_train_program, - # dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) self.assertTrue(dist_context.validate_dist_attr_for_program()) # def test_mlp_misc(self): @@ -212,8 +210,8 @@ def test_mlp_dp_mp(self): # train_program, start_program = mlp_pretrain_forward(train_program, # start_program) # # pdb.set_trace() - # complete_train_program = auto.complete_annotation(train_program, - # dist_context) + # completer = Completer(dist_context) + # complete_train_program = auto.completer.complete_forward_annotation(train_program) # # print_program_with_dist_attr(complete_train_program, # # dist_context) # dist_context.finalize_distributed_attr_for_program( @@ -423,8 +421,9 @@ def test_attn_dp(self): dist_context = DistributedContext() train_program, start_program = attn_pretrain_forward(train_program, start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) # print_program_with_dist_attr(complete_train_program, # dist_context) self.assertTrue(dist_context.validate_dist_attr_for_program()) @@ -440,10 +439,9 @@ def test_attn_mp(self): dist_context = DistributedContext() train_program, start_program = attn_pretrain_forward(train_program, start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) - # print_program_with_dist_attr(complete_train_program, - # dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_attn_dp_mp(self): @@ -458,10 +456,9 @@ def test_attn_dp_mp(self): dist_context = DistributedContext() train_program, start_program = attn_pretrain_forward(train_program, start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) - # print_program_with_dist_attr(complete_train_program, - # dist_context) + completer = Completer(dist_context) + complete_train_program = 
completer.complete_forward_annotation( + train_program) self.assertTrue(dist_context.validate_dist_attr_for_program()) @@ -747,10 +744,9 @@ def test_decoder_dp(self): dist_context = DistributedContext() train_program, start_program = decoder_pretrain_forward(train_program, start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) - # print_program_with_dist_attr(complete_train_program, - # dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_decoder_mp(self): @@ -764,10 +760,9 @@ def test_decoder_mp(self): dist_context = DistributedContext() train_program, start_program = decoder_pretrain_forward(train_program, start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) - # print_program_with_dist_attr(complete_train_program, - # dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_decoder_dp_mp(self): @@ -782,10 +777,9 @@ def test_decoder_dp_mp(self): dist_context = DistributedContext() train_program, start_program = decoder_pretrain_forward(train_program, start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) - # print_program_with_dist_attr(complete_train_program, - # dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) self.assertTrue(dist_context.validate_dist_attr_for_program()) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py index c2c1e63155c3a..1293a9644027d 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py @@ -31,6 +31,7 @@ from paddle.distributed.fleet import fleet import paddle.static as static import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.dist_context import DistributedContext @@ -817,10 +818,9 @@ def test_gpt_dp(self): dist_context = DistributedContext() train_program, start_program = gpt_pretrain_forward(train_program, start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) - # print_program_with_dist_attr(complete_train_program, - # dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_gpt_mp(self): @@ -834,10 +834,9 @@ def test_gpt_mp(self): dist_context = DistributedContext() train_program, start_program = gpt_pretrain_forward(train_program, start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) - # print_program_with_dist_attr(complete_train_program, - # dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) 
self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_gpt_dp_mp(self): @@ -852,10 +851,9 @@ def test_gpt_dp_mp(self): dist_context = DistributedContext() train_program, start_program = gpt_pretrain_forward(train_program, start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) - # print_program_with_dist_attr(complete_train_program, - # dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) self.assertTrue(dist_context.validate_dist_attr_for_program()) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py index 83254de61298b..fd19a5bd8b866 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py @@ -23,6 +23,7 @@ import paddle.nn.functional as F import paddle.utils as utils import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.partitioner import Partitioner @@ -154,8 +155,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): parallelizer._dist_context = dist_context # serial forward & backward completion - complete_train_program = auto.complete_annotation(train_program, - dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) params_grads = parallelizer._generate_backward( complete_train_program, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py index b21cbb5ae78bc..27de9f325063b 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py @@ -18,6 +18,7 @@ import paddle from paddle.fluid import core import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed import fleet from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.partitioner import Partitioner @@ -42,8 +43,9 @@ def get_dist_prog(train_program, parallelizer._dist_context = dist_context # serial forward & backward completion - complete_train_program = auto.complete_annotation( - train_program, dist_context + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program ) if complete_train_program is None else complete_train_program # parallelizer._apply_serial_forward_pass(complete_train_program, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py index 3a28595c833e0..9d4de771076cd 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py @@ -36,6 +36,7 @@ from paddle.distributed import fleet import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from 
paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.partitioner import Partitioner @@ -433,6 +434,12 @@ def forward(self, input): out = F.gelu(out, approximate=True) out = self.linear1(out) + auto.shard_tensor( + out, + dist_attr={ + "process_mesh": _global_process_mesh[1], + "dims_mapping": [0, -1] + }) out = self.linear2(out) out = F.gelu(out, approximate=True) out = self.linear3(out) @@ -476,8 +483,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): parallelizer._dist_context = dist_context # auto completion - complete_train_program = auto.complete_annotation(train_program, - dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) params_grads = parallelizer._generate_backward( complete_train_program, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py index 21cf8a904b690..deff2144411fc 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py @@ -28,6 +28,7 @@ from paddle.fluid import layers from paddle.nn.layer.transformer import _convert_param_attr_to_list import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix @@ -49,8 +50,9 @@ def get_programs(annotated_func): global _global_process_mesh dist_context.process_mesh = _global_process_mesh train_program, start_program = annotated_func(train_program, start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) rank_id = 3 dist_strategy = fleet.DistributedStrategy() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py index dc2ad1d900f52..01e62d886e2b7 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py @@ -31,6 +31,7 @@ from paddle.distributed import fleet import paddle.static as static import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.dist_context import DistributedContext @@ -881,8 +882,9 @@ def test_gpt_dp_mp(self): dist_context.process_mesh = _global_process_mesh train_program, startup_program, loss = gpt_pretrain_forward( train_program, startup_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) # serial backward pass params_grads = parallelizer._generate_backward( @@ -913,8 +915,9 @@ def test_gpt_dp_mp(self): "w") as fw: fw.write(str(auto_parallel_startup_prog)) # with 
open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw: - # from paddle.distributed.auto_parallel.completion import complete_backward_annotation - # complete_backward_annotation(auto_parallel_main_prog) + # from paddle.distributed.auto_parallel.completion import Completer + # completer = Completer() + # completer.complete_forward_annotation(auto_parallel_main_prog) # fw.write(str(auto_parallel_main_prog)) nrank = 4 # col parallel diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py index 614b996d26521..b234e25823f4b 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -22,6 +22,7 @@ import paddle.nn.functional as F import paddle.utils as utils import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer @@ -152,8 +153,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): parallelizer._dist_context = dist_context # serial forward & backward completion - complete_train_program = auto.complete_annotation(train_program, - dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) params_grads = parallelizer._generate_backward( complete_train_program, @@ -299,7 +301,6 @@ def test_mlp_pp(self): for key in list(_g_process_group_map.keys()): del _g_process_group_map[key] reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) - # print_program_with_dist_attr(dist_main_prog, dist_context) # check send and recv result self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py index cfbb7653fad8e..40847a769033a 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py @@ -22,6 +22,7 @@ import paddle.nn.functional as F import paddle.utils as utils import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer @@ -116,8 +117,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): parallelizer._dist_context = dist_context # serial forward & backward completion - complete_train_program = auto.complete_annotation(train_program, - dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) params_grads = parallelizer._generate_backward( complete_train_program, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py index 272c1c212f08e..869bcd4c7ab32 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py @@ -22,6 +22,7 @@ import paddle.nn.functional 
as F import paddle.utils as utils import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer @@ -132,8 +133,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): parallelizer._dist_context = dist_context # serial forward & backward completion - complete_train_program = auto.complete_annotation(train_program, - dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) params_grads = parallelizer._generate_backward( complete_train_program, @@ -263,8 +265,9 @@ def test_allgather(self): dist_context = DistributedContext() dist_strategy = fleet.DistributedStrategy() partitioner = Partitioner(dist_context, rank_id) - complete_train_program = auto.complete_annotation(train_program, - dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition( complete_train_program, startup_program, []) reshard(partitioned_main_prog, partitioned_startup_prog, rank_id, From a20779d7a803b6d07c0190c6c9839bc13503c9f0 Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Fri, 21 Jan 2022 01:54:37 +0000 Subject: [PATCH 11/11] Fix the bug of auto_searcher unittest --- .../fluid/tests/unittests/test_auto_parallel_searcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py index ed64fa0630fa1..78ad64b1dd852 100755 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py @@ -154,7 +154,7 @@ def test_update(self): ops = train_program.global_block().ops vars = train_program.global_block().vars from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container - from paddle.distributed.auto_parallel.completion import is_elementwise_like_op + from paddle.distributed.auto_parallel.operators.common import is_elementwise_op from paddle.distributed.auto_parallel.dist_op import DistributedOperator for op in ops: @@ -163,7 +163,7 @@ def test_update(self): if dist_op_impl_container is None: op_dist_attr = dist_context.get_op_dist_attr_for_program(op) dist_op = DistributedOperator(op, op_dist_attr) - if is_elementwise_like_op(op.type): + if is_elementwise_op(op.type): changed = update_op_dims_mapping_by_elementwise_like_dist_impl( dist_op) self.assertFalse(changed)
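
A minimal sketch of the usage pattern the hunks above converge on, assuming the Completer interface introduced by this series (a DistributedContext-bound object whose complete_forward_annotation, complete_backward_annotation and complete_update_annotation methods replace the old module-level complete_annotation helpers). Variable names such as train_program are illustrative placeholders for an already-built static Program, not code from the patch itself:

    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    # Bind a completer to the distributed context that should receive the
    # inferred distributed attributes (sketch only; train_program is assumed
    # to be an existing paddle.static.Program).
    dist_context = DistributedContext()
    completer = Completer(dist_context)

    # Forward completion returns the annotated main program, as in the
    # updated unit tests above.
    complete_train_program = completer.complete_forward_annotation(train_program)

    # After the backward pass and optimizer ops have been appended, the same
    # completer annotates them in place, mirroring the parallelizer.py hunk.
    completer.complete_backward_annotation(complete_train_program)
    completer.complete_update_annotation(complete_train_program)

As the update-annotation hunk shows, optimizer inputs are completed relative to the parameter being updated: LearningRate and the Beta*Pow accumulators are mapped as replicated ([-1]), while the Moment tensors reuse the parameter's dims_mapping and process_mesh.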