From a97b507e233f80a65193477e16d9677bc2a115ce Mon Sep 17 00:00:00 2001 From: Yichen Zhang <32740647+pkuzyc@users.noreply.github.com> Date: Mon, 14 Aug 2023 10:11:59 +0800 Subject: [PATCH] [Semi-Auto] Add reshape spmd rule (#55177) * add reshape spmd rule * add unit test for reshape spmd rule * bug fix * replace the print_info function with to_string * fix typo * bug fix * add handling for "0" in target shape * remove the part of computing size in dim_trans.cc --- .../auto_parallel/spmd_rules/common.h | 6 +- .../auto_parallel/spmd_rules/dim_trans.cc | 355 ++++++++++++++++++ .../auto_parallel/spmd_rules/dim_trans.h | 160 ++++++++ .../spmd_rules/reshape_spmd_rule.cc | 206 ++++++++++ .../spmd_rules/reshape_spmd_rule.h | 40 ++ .../auto_parallel/spmd_rules/rules.h | 4 + test/auto_parallel/spmd_rules/CMakeLists.txt | 1 + .../spmd_rules/test_reshape_rule.py | 219 +++++++++++ 8 files changed, 988 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.cc create mode 100644 paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h create mode 100644 paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc create mode 100644 paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h create mode 100644 test/auto_parallel/spmd_rules/test_reshape_rule.py diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h index f5a49ab0a9f18..26c421eb27e23 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h @@ -125,14 +125,14 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim, TensorDistAttr ReplicatedOnMesh(const TensorDistAttr& src_dist_attr); // Check whether the given DistTensorSpec objects are valid. 
For each -// DistTensorSpec, the rank of its dimsmapping must be equal to the rank of its +// DistTensorSpec, the rank of its dims mapping must be equal to the rank of its // corresponding tensor shape. the parameter op_name is used for logging error // message. void VerifySpecs(const std::vector& specs, const std::string& op_name); -// Get dimsmapping for the given tensors. Return the pair of each -// tensor's einsum notation and the corresponding dimsmapping. +// Get dims mapping for the given tensors. Return the pair of each +// tensor's einsum notation and the corresponding dims mapping. std::vector>> GetAxesDimsMappingPair(const std::vector& tensor_axes, const std::vector& specs); diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.cc new file mode 100644 index 0000000000000..993793a7d64ec --- /dev/null +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.cc @@ -0,0 +1,355 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h"
+#include <cassert>
+#include <set>
+#include <string>
+#include <vector>
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
+#include "paddle/phi/core/enforce.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+
+// Registry of every DimTrans object built by the constructors below. Each
+// constructor of a concrete transformation appends `this`, and CleanUp()
+// deletes all registered objects, so they must be created with `new`.
+static std::vector<DimTrans*> all_dim_trans;
+
+DimTrans::DimTrans(Type type) : type_(type) {}
+
+DimTrans::~DimTrans() {}
+
+DimTrans::Type DimTrans::type() const { return type_; }
+
+void DimTrans::set_type(Type type) { type_ = type; }
+
+std::string DimTrans::to_string() { return std::string(""); }
+
+InputDim::InputDim() : DimTrans(DimTrans::Type::INPUTDIM) {
+  // -1 marks "no input axis assigned yet"
+  input_dim_ = -1;
+  all_dim_trans.emplace_back(this);
+}
+
+InputDim::InputDim(int64_t dim) : DimTrans(DimTrans::Type::INPUTDIM) {
+  input_dim_ = dim;
+  all_dim_trans.emplace_back(this);
+}
+
+InputDim::~InputDim() {}
+
+int64_t InputDim::input_dim() const { return input_dim_; }
+
+void InputDim::set_input_dim(int64_t dim) { input_dim_ = dim; }
+
+std::string InputDim::to_string() {
+  return ("InputDim(" + std::to_string(input_dim_) + ")");
+}
+
+Singleton::Singleton() : DimTrans(DimTrans::Type::SINGLETON) {
+  all_dim_trans.emplace_back(this);
+}
+
+std::string Singleton::to_string() { return "Singleton()"; }
+
+Flatten::Flatten() : DimTrans(DimTrans::Type::FLATTEN) {
+  all_dim_trans.emplace_back(this);
+}
+
+Flatten::Flatten(const std::vector<DimTrans*>& dims)
+    : DimTrans(DimTrans::Type::FLATTEN) {
+  input_dims_ = dims;
+  all_dim_trans.emplace_back(this);
+}
+
+Flatten::~Flatten() {
+  // The inputs are not owned by this object (CleanUp() deletes them), so
+  // only the pointers are dropped here.
+  input_dims_.assign(input_dims_.size(), nullptr);
+  std::vector<DimTrans*>().swap(input_dims_);
+}
+
+const std::vector<DimTrans*>& Flatten::inputs() const { return input_dims_; }
+
+void Flatten::set_inputs(const std::vector<DimTrans*>& dims) {
+  input_dims_.assign(dims.begin(), dims.end());
+}
+
+std::string Flatten::to_string() {
+  std::string ret_str("Flatten(");
+  for (int64_t i = 0, n = input_dims_.size(); i < n; ++i) {
+    ret_str += input_dims_[i]->to_string();
+    if (i < n - 1) {
+      ret_str += ",";
+    }
+  }
+  return ret_str + ")";
+}
+
+Split::Split() : DimTrans(DimTrans::Type::SPLIT) {
+  input_dim_trans_ = nullptr;
+  // keep split_id_ defined for a default-constructed object
+  split_id_ = 0;
+  all_dim_trans.emplace_back(this);
+}
+
+Split::Split(DimTrans* dim, const std::vector<int64_t>& shape, int64_t id)
+    : DimTrans(DimTrans::Type::SPLIT) {
+  input_dim_trans_ = dim;
+  split_id_ = id;
+  splitted_shape_.assign(shape.begin(), shape.end());
+  all_dim_trans.emplace_back(this);
+}
+
+Split::~Split() {
+  // The input is not owned by this object (CleanUp() deletes it).
+  input_dim_trans_ = nullptr;
+  std::vector<int64_t>().swap(splitted_shape_);
+}
+
+DimTrans* Split::input() const { return input_dim_trans_; }
+
+void Split::set_input(DimTrans* dim) { input_dim_trans_ = dim; }
+
+int64_t Split::split_id() const { return split_id_; }
+
+int64_t Split::local_splitted_shape_value() {
+  return splitted_shape_[split_id_];
+}
+
+std::string Split::to_string() {
+  std::string ret_str("Split(");
+  ret_str += input_dim_trans_->to_string() + ", (";
+  for (int64_t i = 0, n = splitted_shape_.size(); i < n; ++i) {
+    ret_str += std::to_string(splitted_shape_[i]);
+    if (i < n - 1) {
+      ret_str += ",";
+    }
+  }
+  return ret_str + "), " + std::to_string(split_id_) + ")";
+}
+
+// Build the transformation for an output axis that merges `dims`:
+// no input -> Singleton, one input -> that input itself, otherwise Flatten.
+DimTrans* make_flatten(const std::vector<DimTrans*>& dims) {
+  DimTrans* ptr = nullptr;
+  if (dims.size() == 0) {
+    ptr = new Singleton();
+  } else if (dims.size() == 1) {
+    ptr = dims[0];
+  } else {
+    ptr = new Flatten(dims);
+  }
+  return ptr;
+}
+
+// Build the transformation for the `id`-th output axis obtained by splitting
+// `dim` into `shape`. A one-element `shape` is no split at all and returns
+// `dim`; an output axis of size 1 becomes a Singleton; otherwise a Split is
+// created over `shape` with its size-1 entries removed.
+DimTrans* make_split(DimTrans* dim,
+                     const std::vector<int64_t>& shape,
+                     int64_t id) {
+  assert(shape.size() > 0);
+  DimTrans* ptr = nullptr;
+  if (shape.size() == 1) {
+    assert(id == 0);
+    ptr = dim;
+  } else if (shape[id] == 1) {
+    ptr = new Singleton();
+  } else {
+    // new shape with the size-1 entries removed
+    std::vector<int64_t> new_shape;
+    // map from an index in `shape` to its index in `new_shape`
+    std::vector<int64_t> idx_map(shape.size(), -1);
+    for (int64_t i = 0, n = shape.size(); i < n; ++i) {
+      // Test the i-th entry. Testing shape[id] (always != 1 in this branch)
+      // would copy every entry and never drop the size-1 ones.
+      if (shape[i] != 1) {
+        idx_map[i] = new_shape.size();
+        new_shape.emplace_back(shape[i]);
+      }
+    }
+    ptr = new Split(dim, new_shape, idx_map[id]);
+  }
+  return ptr;
+}
+
+// Delete every DimTrans registered in all_dim_trans and empty the registry.
+void CleanUp() {
+  for (int64_t i = 0, n = all_dim_trans.size(); i < n; i++) {
+    if (all_dim_trans[i]) {
+      delete all_dim_trans[i];
+      all_dim_trans[i] = nullptr;
+    }
+  }
+  std::vector<DimTrans*>().swap(all_dim_trans);
+}
+
+// Given a `dim_trans` of an output axis, get the input axis
+// whose dim mapping should be propagated to it.
+// If the returned input axis is none, the output axis's
+// dim mapping should be set to -1 (replicated). For an axis
+// that is flattened from input axes, return the leftmost
+// flattened input axis. For the split transformation,
+// only the leftmost split axis in output will return its input.
+// As a side effect, `shardable` (indexed [input axis][mesh dim]) is updated
+// and every visited input axis is inserted into `seen_dims`.
+DimTrans* GetDimTrans(DimTrans* dim_trans,
+                      std::vector<std::vector<bool>>* shardable,
+                      std::set<int64_t>* seen_dims,
+                      const std::vector<int64_t>& input_shape,
+                      const std::vector<int64_t>& mesh_shape,
+                      const std::vector<int64_t>& input_dims_mapping,
+                      const std::set<int64_t>& sharded_input_dims) {
+  DimTrans::Type type = dim_trans->type();
+  DimTrans* ret_dim_trans = nullptr;
+
+  if (type == DimTrans::Type::INPUTDIM) {
+    InputDim* inputdim = dynamic_cast<InputDim*>(dim_trans);
+    int64_t dim = inputdim->input_dim();
+    seen_dims->insert(dim);
+
+    if (sharded_input_dims.count(dim) > 0) {
+      ret_dim_trans = dim_trans;
+    }
+  } else if (type == DimTrans::Type::FLATTEN) {
+    Flatten* flatten = dynamic_cast<Flatten*>(dim_trans);
+    const std::vector<DimTrans*>& inputs = flatten->inputs();
+    int64_t nmesh = mesh_shape.size();
+    // Every flattened input except the leftmost one cannot stay sharded.
+    for (int64_t i = 1, n = inputs.size(); i < n; i++) {
+      DimTrans* input = inputs[i];
+      if (input->type() == DimTrans::Type::INPUTDIM) {
+        // `shardable` is indexed by input axis, not by the position of the
+        // axis inside this flatten group.
+        InputDim* inputdim = dynamic_cast<InputDim*>(input);
+        (*shardable)[inputdim->input_dim()].assign(nmesh, false);
+      }
+
+      GetDimTrans(input,
+                  shardable,
+                  seen_dims,
+                  input_shape,
+                  mesh_shape,
+                  input_dims_mapping,
+                  sharded_input_dims);
+    }
+
+    DimTrans* dim0 = inputs[0];
+    if (dim0->type() == DimTrans::Type::INPUTDIM) {
+      InputDim* inputdim = dynamic_cast<InputDim*>(dim0);
+      if (sharded_input_dims.count(inputdim->input_dim()) > 0) {
+        ret_dim_trans = dim0;
+      }
+    }
+  } else if (type == DimTrans::Type::SPLIT) {
+    Split* split = dynamic_cast<Split*>(dim_trans);
+    DimTrans* dim = GetDimTrans(split->input(),
+                                shardable,
+                                seen_dims,
+                                input_shape,
+                                mesh_shape,
+                                input_dims_mapping,
+                                sharded_input_dims);
+    int64_t ret_size = split->local_splitted_shape_value();
+
+    if (split->split_id() == 0) {
+      if (dim != nullptr) {
+        PADDLE_ENFORCE_EQ(dim->type(),
+                          DimTrans::Type::INPUTDIM,
+                          phi::errors::InvalidArgument(
+                              "The returned dim_trans must be INPUTDIM."));
+        InputDim* inputdim = dynamic_cast<InputDim*>(dim);
+        int64_t nmesh = mesh_shape.size();
+        int64_t input_axis = inputdim->input_dim();
+
+        // Check whether the sharded dim can be sharded on
+        // each mesh dimension. The dimension should be
+        // divisible by the mesh size that it is sharded on
+        for (int64_t imesh = 0; imesh < nmesh; imesh++) {
+          (*shardable)[input_axis][imesh] = (ret_size % mesh_shape[imesh] == 0);
+        }
+      }
+      ret_dim_trans = dim;
+    }
+  } else if (type == DimTrans::Type::SINGLETON) {
+    ret_dim_trans = nullptr;
+  }
+  return ret_dim_trans;
+}
+
+// Collect into `seen_dims` every input axis referenced by `dim_trans`,
+// looking through Flatten and Split transformations.
+void GetUsedInputDim(DimTrans* dim_trans, std::set<int64_t>* seen_dims) {
+  if (dim_trans->type() == DimTrans::Type::INPUTDIM) {
+    InputDim* input = dynamic_cast<InputDim*>(dim_trans);
+    seen_dims->insert(input->input_dim());
+  } else if (dim_trans->type() == DimTrans::Type::FLATTEN) {
+    Flatten* flatten = dynamic_cast<Flatten*>(dim_trans);
+    for (DimTrans* trans : flatten->inputs()) {
+      GetUsedInputDim(trans, seen_dims);
+    }
+  } else if (dim_trans->type() == DimTrans::Type::SPLIT) {
+    Split* split = dynamic_cast<Split*>(dim_trans);
+    GetUsedInputDim(split->input(), seen_dims);
+  } else {
+    return;
+  }
+}
+
+// Returns {input dims mapping, output dims mapping}. The input dims mapping
+// is a copy of the one in `input_spec` with every axis that cannot keep its
+// sharding reset to -1.
+std::vector<std::vector<int64_t>> InferFromDimTrans(
+    const DistTensorSpec& input_spec, const std::vector<DimTrans*>& dim_trans) {
+  const std::vector<int64_t>& input_shape = input_spec.shape();
+  const std::vector<int64_t>& input_dims_mapping = input_spec.dims_mapping();
+  const ProcessMesh& mesh = input_spec.dist_attr().process_mesh();
+  const std::vector<int64_t>& mesh_shape = mesh.shape();
+
+  std::set<int64_t> sharded_input_dims;
+  for (int64_t i = 0, n = input_dims_mapping.size(); i < n; ++i) {
+    if (input_dims_mapping[i] > -1) {
+      sharded_input_dims.insert(i);
+    }
+  }
+  int64_t ndim = input_shape.size();
+  int64_t nmesh = mesh_shape.size();
+  std::vector<std::vector<bool>> shardable(ndim,
+                                           std::vector<bool>(nmesh, true));
+
+  std::set<int64_t> seen_input_dims;
+  for (DimTrans* trans : dim_trans) {
+    GetUsedInputDim(trans, &seen_input_dims);
+  }
+
+  // an input axis that no output axis refers to cannot be sharded
+  for (int64_t idim = 0; idim < ndim; idim++) {
+    bool seen = seen_input_dims.count(idim);
+    if (!seen) {
+      shardable[idim].assign(nmesh, seen);
+    }
+  }
+
+  // get the map from sharded input dimensions to output dimensions.
+  std::vector<int64_t> dim_map_src2tgt(ndim, -1);
+  for (int64_t i = 0, n = dim_trans.size(); i < n; i++) {
+    DimTrans* dim = GetDimTrans(dim_trans[i],
+                                &shardable,
+                                &seen_input_dims,
+                                input_shape,
+                                mesh_shape,
+                                input_dims_mapping,
+                                sharded_input_dims);
+    if (dim != nullptr && dim->type() == DimTrans::Type::INPUTDIM) {
+      InputDim* inputdim = dynamic_cast<InputDim*>(dim);
+      dim_map_src2tgt[inputdim->input_dim()] = i;
+    }
+  }
+
+  std::vector<int64_t> out_dims_mapping(dim_trans.size(), -1);
+  std::vector<int64_t> new_input_dims_mapping(input_dims_mapping);
+
+  // set output dims mapping with corresponding input dimensions.
+  // if one input dimension is sharded on a unshardable mesh after
+  // splitting, we need to make it replicated.
+  for (int64_t i = 0; i < ndim; i++) {
+    int64_t mesh_dim = input_dims_mapping[i];
+    if (mesh_dim > -1 && shardable[i][mesh_dim] && dim_map_src2tgt[i] > -1) {
+      out_dims_mapping[dim_map_src2tgt[i]] = input_dims_mapping[i];
+    } else {
+      new_input_dims_mapping[i] = -1;
+    }
+  }
+
+  return {new_input_dims_mapping, out_dims_mapping};
+}
+
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h
new file mode 100644
index 0000000000000..f196a0266d5d4
--- /dev/null
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h
@@ -0,0 +1,160 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+
+// This is a base class to describe how each dimension in output tensor
+// is transformed from input tensor's axes. The transformation includes
+// Flatten, Split, etc. A vector<DimTrans*> whose size equals to the
+// output tensor's rank can be used to describe how the output shape is
+// transformed from the input shape. Each element in vector<DimTrans*>
+// describes the transformation of one output axis.
For example, when
+// a reshape operator reshapes a tensor from the shape of (6, 12, 48)
+// to (72, 6, 8), this transformation can be described as:
+// [Flatten(Dim(0), Dim(1)), Split(Dim(2), (6,8), 0), Split(Dim(2), (6,8), 1)]
+// meaning that dim0 in output is flattened from dim0 and dim1 in input,
+// dim1 and dim2 in output are obtained by splitting dim2 in input, the
+// split shape is (6, 8), dim1 refers to the first shape value in (6, 8)
+// and dim2 refers to the second shape value in (6, 8).
+class DimTrans {
+ public:
+  // Kind of transformation that produces one output axis.
+  enum class Type { INPUTDIM, SINGLETON, FLATTEN, SPLIT };
+
+  DimTrans() = default;
+
+  explicit DimTrans(Type type);
+
+  virtual ~DimTrans();
+
+  Type type() const;
+
+  void set_type(Type type);
+
+  // Human-readable form of the transformation, used for logging.
+  virtual std::string to_string();
+
+ private:
+  Type type_;
+};
+
+// InputDim indicates that the output dimension
+// is obtained directly from one input dimension.
+class InputDim : public DimTrans {
+ public:
+  InputDim();
+
+  explicit InputDim(int64_t dim);
+
+  virtual ~InputDim();
+
+  // Index of the input axis this output axis comes from.
+  int64_t input_dim() const;
+
+  void set_input_dim(int64_t dim);
+
+  std::string to_string() override;
+
+ private:
+  int64_t input_dim_;
+};
+
+// Singleton indicates that the shape of the
+// corresponding output dimension is 1
+class Singleton : public DimTrans {
+ public:
+  Singleton();
+  std::string to_string() override;
+};
+
+// Flatten indicates that the output dimension
+// is obtained from flattening input dimensions.
+class Flatten : public DimTrans {
+ public:
+  Flatten();
+
+  explicit Flatten(const std::vector<DimTrans*>& dims);
+
+  virtual ~Flatten();
+
+  // The transformations that are merged into this output axis.
+  const std::vector<DimTrans*>& inputs() const;
+
+  void set_inputs(const std::vector<DimTrans*>& dims);
+
+  std::string to_string() override;
+
+ private:
+  std::vector<DimTrans*> input_dims_;
+};
+
+// Split indicates that the output dimension
+// is obtained by splitting input dimension.
+class Split : public DimTrans {
+ public:
+  Split();
+
+  // `dim` is the transformation being split, `shape` the sizes it is split
+  // into and `id` the index in `shape` that this output axis corresponds to.
+  Split(DimTrans* dim, const std::vector<int64_t>& shape, int64_t id);
+
+  virtual ~Split();
+
+  DimTrans* input() const;
+
+  void set_input(DimTrans* dim);
+
+  int64_t split_id() const;
+
+  // get the split shape value of the split_id_ dimension
+  int64_t local_splitted_shape_value();
+
+  std::string to_string() override;
+
+ private:
+  DimTrans* input_dim_trans_;
+  std::vector<int64_t> splitted_shape_;
+  int64_t split_id_;
+};
+
+// Release every DimTrans object created so far.
+void CleanUp();
+
+// Build the transformation of an output axis that merges `dims`.
+DimTrans* make_flatten(const std::vector<DimTrans*>& dims = {});
+
+// Build the transformation of the `id`-th output axis obtained by
+// splitting `dim` into `shape`.
+DimTrans* make_split(DimTrans* dim,
+                     const std::vector<int64_t>& shape = {},
+                     int64_t id = 0);
+
+// Infer the dims mapping of the output tensor according to the transformation
+// `dim_trans`. Returns the dims mapping of the input tensor (the input dims
+// mapping may be changed for resharding) and output tensor. The inferring
+// follows the rules:
+// 1. For Singleton, i.e., the shape of this output axis is 1, its dim mapping
+// is -1, indicating that the output axis is replicated.
+// 2. For InputDim, i.e., the output axis is transformed directly from an input
+// axis, set its dim mapping equals to the corresponding input axis.
+// 3. For Flatten, i.e., the output axis is flattened from some input axes, it
+// can be sharded only if the leftmost flattened axis is sharded.
+// 4. For Split, i.e., the output axes are split from an input axis, only the
+// leftmost output split axis can be sharded when its shape can be divisible
+// by the mesh dimension.
+std::vector> InferFromDimTrans( + const DistTensorSpec& input_spec, const std::vector& dim_trans); + +} // namespace auto_parallel +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc new file mode 100644 index 0000000000000..0b64a4f00ecde --- /dev/null +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc @@ -0,0 +1,206 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h" +#include +#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" + +namespace paddle { +namespace distributed { +namespace auto_parallel { + +using phi::distributed::auto_parallel::str_join; + +// The target shape in reshape may contains a -1 dimension, +// this function is used to infer what the "-1" dimension is. 
+// Returns a copy of `shape` in which the single -1 entry (if any) is
+// replaced so that the total number of elements equals `len`. Raises
+// InvalidArgument when `shape` cannot describe `len` elements.
+std::vector<int64_t> InferTargetShape(const std::vector<int64_t>& shape,
+                                      int64_t len) {
+  int64_t infer_idx = -1;
+  for (int64_t i = 0, n = shape.size(); i < n; i++) {
+    if (shape[i] == -1) {
+      PADDLE_ENFORCE_EQ(
+          infer_idx,
+          -1,
+          phi::errors::InvalidArgument(
+              "There can't be more than one -1 dimension in target shape."));
+      infer_idx = i;
+    }
+  }
+
+  // The initial value must be int64_t: with a plain `1` the product would
+  // be accumulated as int and could overflow for large shapes.
+  int64_t product = std::accumulate(
+      shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
+
+  if (infer_idx == -1) {
+    // nothing to infer, the shape only has to be valid and match `len`
+    PADDLE_ENFORCE_GT(product,
+                      0,
+                      phi::errors::InvalidArgument(
+                          "The dimensions in target shape must be positive."));
+    PADDLE_ENFORCE_EQ(
+        product,
+        len,
+        phi::errors::InvalidArgument("The total size are not matched"));
+    return std::vector<int64_t>(shape);
+  }
+
+  // Exactly one entry is -1, so `product` is minus the product of the
+  // known dimensions.
+  const int64_t known_size = -product;
+  PADDLE_ENFORCE_GT(
+      known_size,
+      0,
+      phi::errors::InvalidArgument(
+          "The dimensions in target shape must be positive except for -1."));
+  PADDLE_ENFORCE_EQ(len % known_size,
+                    0,
+                    phi::errors::InvalidArgument(
+                        "The total size is not divisible by the product of "
+                        "the known dimensions in target shape."));
+  std::vector<int64_t> new_shape(shape);
+  new_shape[infer_idx] = len / known_size;
+  return new_shape;
+}
+
+// Compute how each dimension in target shape
+// is obtained from the input dimensions
+std::vector<DimTrans*> MakeReshapeDimTrans(
+    const std::vector<int64_t>& src_shape,
+    const std::vector<int64_t>& tgt_shape) {
+  std::vector<DimTrans*> ret;
+  int64_t total_elem_num_src = std::accumulate(src_shape.begin(),
+                                               src_shape.end(),
+                                               int64_t{1},
+                                               std::multiplies<int64_t>());
+  std::vector<int64_t> inferred_tgt_shape =
+      InferTargetShape(tgt_shape, total_elem_num_src);
+
+  int64_t src_idx = 0, tgt_idx = 0;
+  int64_t s, t;
+  int64_t src_len, tgt_len;
+  src_len = src_shape.size();
+  tgt_len = inferred_tgt_shape.size();
+  // Each iteration consumes one group of source axes and one group of
+  // target axes holding the same number of elements.
+  while (src_idx < src_len || tgt_idx < tgt_len) {
+    std::vector<int64_t> src_dims, tgt_splitted_shape;
+    if (src_idx >= src_len) {
+      s = 1;
+    } else {
+      s = src_shape[src_idx];
+      src_dims.emplace_back(src_idx);
+      src_idx++;
+    }
+    if (tgt_idx >= tgt_len) {
+      t = 1;
+    } else {
+      // read the inferred shape: the raw target shape may still hold -1
+      t = inferred_tgt_shape[tgt_idx];
+      tgt_splitted_shape.emplace_back(t);
+      tgt_idx++;
+    }
+
+    // deal with the singleton case
+    if (s == 1 && t != 1) {
+      // case [1] [a]
+      tgt_idx--;
+      tgt_splitted_shape.clear();
+    } else if (s != 1 && t == 1) {
+      // case [a] [1]
+      src_idx--;
+      src_dims.clear();
+    } else {
+      // grow the smaller group until both hold the same number of elements
+      while (s != t) {
+        if (s < t) {
+          src_dims.emplace_back(src_idx);
+          s *= src_shape[src_idx];
+          src_idx++;
+        } else {
+          tgt_splitted_shape.emplace_back(inferred_tgt_shape[tgt_idx]);
+          t *= inferred_tgt_shape[tgt_idx];
+          tgt_idx++;
+        }
+      }
+    }
+
+    if (tgt_splitted_shape.size() > 0) {
+      // source axes of size 1 do not take part in the flatten
+      std::vector<DimTrans*> input_dims;
+      for (int64_t i = 0, n = src_dims.size(); i < n; i++) {
+        int64_t in_dim = src_dims[i];
+        if (src_shape[in_dim] > 1) {
+          input_dims.emplace_back(new InputDim(in_dim));
+        }
+      }
+      DimTrans* flatten = make_flatten(input_dims);
+
+      for (int64_t i = 0, n = tgt_splitted_shape.size(); i < n; i++) {
+        ret.emplace_back(make_split(flatten, tgt_splitted_shape, i));
+      }
+    }
+  }
+  return ret;
+}
+
+// Infer the dist attributes of the input (possibly resharded) and of the
+// output of reshape from the single input spec and the "shape" attribute.
+std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(
+    const std::vector<DistTensorSpec>& input_specs,
+    const paddle::framework::AttributeMap& attrs) {
+  // step0: Verify Input Args Based on Reshape Logic
+  int64_t ninputs = input_specs.size();
+  PADDLE_ENFORCE_EQ(
+      ninputs,
+      1,
+      phi::errors::InvalidArgument("The size of InputSpec in reshape must "
+                                   "be equal to 1, but got [%d].",
+                                   ninputs));
+  VerifySpecs(input_specs, "reshape");
+
+  // step1: build the transformation from
+  // original shape to target shape
+  std::vector<int64_t> src_shape = input_specs[0].shape();
+  std::vector<int64_t> tgt_shape =
+      ExtractAttr<std::vector<int64_t>>("shape", attrs);
+
+  // handle the '0' values in target shape, '0' indicates
+  // that the target shape is equal to the source shape
+  int64_t src_ndim = src_shape.size();
+  for (int64_t i = 0, n = tgt_shape.size(); i < n; i++) {
+    if (tgt_shape[i] == 0) {
+      PADDLE_ENFORCE_LT(
+          i,
+          src_ndim,
+          phi::errors::InvalidArgument(
+              "The index [%d] of '0' in target shape must be less than "
+              "the rank [%d] of the input tensor.",
+              i,
+              src_ndim));
+      tgt_shape[i] = src_shape[i];
+    }
+  }
+
+  std::vector<DimTrans*> trans = MakeReshapeDimTrans(src_shape, tgt_shape);
+
+  // step2: infer the dims mapping of input (if reshard is
+  // needed) and output from the dimension transformation.
+  std::vector<std::vector<int64_t>> dims_mapping_vec =
+      InferFromDimTrans(input_specs[0], trans);
+
+  // step3: update the dist attributes of input
+  // and output with the inferred dims mapping
+  TensorDistAttr new_input_dist_attr(input_specs[0].dist_attr());
+  new_input_dist_attr.set_dims_mapping(dims_mapping_vec[0]);
+  TensorDistAttr output_dist_attr(input_specs[0].dist_attr());
+  output_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
+
+  VLOG(4) << "Reshape: input_shape: [" << str_join(src_shape)
+          << "] output_shape: [" << str_join(tgt_shape) << "]";
+  VLOG(4) << "Transformation from input to output:";
+  for (int64_t i = 0, n = trans.size(); i < n; i++) {
+    DimTrans* t = trans[i];
+    VLOG(4) << "\tOutput axis " << i << ": " << t->to_string();
+  }
+  VLOG(4) << "input_dims_mapping: [" << str_join(dims_mapping_vec[0])
+          << "] output_dims_mapping: [" << str_join(dims_mapping_vec[1])
+          << "]\n\n";
+
+  // release the DimTrans objects created while building `trans`
+  CleanUp();
+
+  return {{new_input_dist_attr}, {output_dist_attr}};
+}
+
+// Backward inference is not supported yet and always raises Unimplemented.
+std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+paddle::distributed::auto_parallel::ReshapeSPMDRule::InferBackward(
+    const std::vector<DistTensorSpec>& output_specs,
+    const paddle::framework::AttributeMap& attrs) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "InferBackward of ReshapeSPMDRule is NOT implemented yet."));
+
+  return {};
+}
+
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h
new file mode 100644
index 0000000000000..63b9a5a6f038a
--- /dev/null
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h
@@ -0,0 +1,40 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" + +namespace paddle { +namespace distributed { +namespace auto_parallel { + +class ReshapeSPMDRule : public SPMDRuleBase { + public: + std::pair, std::vector> + InferForward(const std::vector& input_specs, + const paddle::framework::AttributeMap& attrs) override; + + std::pair, std::vector> + InferBackward(const std::vector& output_specs, + const paddle::framework::AttributeMap& attrs) override; +}; +} // namespace auto_parallel +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h index 713a52770926d..cf4046950964a 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h @@ -22,6 +22,7 @@ #include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h" +#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h" @@ -159,6 +160,9 @@ REGISTER_SPMD_RULE(split_with_num, SplitSPMDRule); // transpose rule 
REGISTER_SPMD_RULE(transpose, TransposeSPMDRule); +// reshape rule +REGISTER_SPMD_RULE(reshape, ReshapeSPMDRule); + } // namespace auto_parallel } // namespace distributed } // namespace paddle diff --git a/test/auto_parallel/spmd_rules/CMakeLists.txt b/test/auto_parallel/spmd_rules/CMakeLists.txt index 43afd9aed75e7..c981aee6f83e1 100644 --- a/test/auto_parallel/spmd_rules/CMakeLists.txt +++ b/test/auto_parallel/spmd_rules/CMakeLists.txt @@ -10,6 +10,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_matmul_rule MODULES test_softmax_rule) py_test_modules(test_split_rule MODULES test_split_rule) py_test_modules(test_transpose_rule MODULES test_transpose_rule) + py_test_modules(test_reshape_rule MODULES test_reshape_rule) # End of unittests WITH single card WITHOUT timeout endif() diff --git a/test/auto_parallel/spmd_rules/test_reshape_rule.py b/test/auto_parallel/spmd_rules/test_reshape_rule.py new file mode 100644 index 0000000000000..8999bc3e34c38 --- /dev/null +++ b/test/auto_parallel/spmd_rules/test_reshape_rule.py @@ -0,0 +1,219 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
+from paddle.distributed.auto_parallel.static.dist_attribute import (
+    DistTensorSpec,
+    TensorDistAttr,
+)
+from paddle.distributed.fleet import auto
+
+
+class TestReshapeSPMDRule(unittest.TestCase):
+    """Forward inference of the reshape SPMD rule on a [6, 12, 48, 24] input
+    placed on a 2 x 3 process mesh."""
+
+    def setUp(self):
+        self.rule = get_spmd_rule("reshape")
+
+        x_shape = [6, 12, 48, 24]
+        process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
+
+        x_tensor_dist_attr = TensorDistAttr()
+        # one entry per axis of x_shape; each case overrides it before use
+        x_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1]
+        x_tensor_dist_attr.process_mesh = process_mesh
+        self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
+
+        self.attrs = {"shape": [1, 72, 48, 4, 6]}
+
+    def test_reshape_infer_forward(self):
+        # Each case is
+        # (target shape, input dims_mapping,
+        #  expected input dims_mapping, expected output dims_mapping)
+        # for the input shape [6, 12, 48, 24].
+        cases = [
+            (
+                [1, 72, 48, 4, 6],
+                [0, -1, 1, -1],
+                [0, -1, 1, -1],
+                [-1, 0, 1, -1, -1],
+            ),
+            (
+                [1, 72, 48, 4, 6],
+                [-1, 0, -1, 1],
+                [-1, -1, -1, -1],
+                [-1, -1, -1, -1, -1],
+            ),
+            (
+                [1, 72, 48, 4, 6],
+                [1, -1, -1, 0],
+                [1, -1, -1, 0],
+                [-1, 1, -1, 0, -1],
+            ),
+            (
+                [3, 24, 6, 8, 24],
+                [0, 1, -1, -1],
+                [-1, -1, -1, -1],
+                [-1, -1, -1, -1, -1],
+            ),
+            (
+                [3, 24, 6, 8, 24],
+                [1, -1, -1, 0],
+                [1, -1, -1, 0],
+                [1, -1, -1, -1, 0],
+            ),
+            # -1 in the target shape is inferred
+            (
+                [3, 24, 6, -1, 24],
+                [-1, -1, 0, 1],
+                [-1, -1, 0, 1],
+                [-1, -1, 0, -1, 1],
+            ),
+            # 0 in the target shape copies the input dimension
+            (
+                [1, 72, 0, 4, 6],
+                [1, -1, -1, 0],
+                [1, -1, -1, 0],
+                [-1, 1, -1, 0, -1],
+            ),
+            (
+                [6, 12, 48, 24],
+                [-1, -1, 0, 1],
+                [-1, -1, 0, 1],
+                [-1, -1, 0, 1],
+            ),
+            (
+                [72, 3, 16, 24],
+                [0, -1, 1, -1],
+                [0, -1, 1, -1],
+                [0, 1, -1, -1],
+            ),
+            (
+                [72, 3, 16, 24],
+                [1, -1, 0, -1],
+                [1, -1, -1, -1],
+                [1, -1, -1, -1],
+            ),
+        ]
+
+        for tgt_shape, dims_mapping, expected_in, expected_out in cases:
+            with self.subTest(shape=tgt_shape, dims_mapping=dims_mapping):
+                self.attrs["shape"] = tgt_shape
+                self.x_dist_tensor_spec.set_dims_mapping(dims_mapping)
+                result_dist_attrs = self.rule.infer_forward(
+                    [self.x_dist_tensor_spec], self.attrs
+                )
+                infered_input_dist_attrs = result_dist_attrs[0]
+                infered_output_dist_attrs = result_dist_attrs[1]
+
+                self.assertEqual(len(infered_input_dist_attrs), 1)
+                self.assertEqual(len(infered_output_dist_attrs), 1)
+                self.assertEqual(
+                    infered_input_dist_attrs[0].dims_mapping, expected_in
+                )
+                self.assertEqual(
+                    infered_output_dist_attrs[0].dims_mapping, expected_out
+                )
+
+        # shape: [6, 12, 48, 24] --> [3, 24, 6, -1, -1]
+        # more than one -1 in the target shape raises an error
+        self.attrs["shape"] = [3, 24, 6, -1, -1]
+        with self.assertRaises(BaseException):
+            self.rule.infer_forward([self.x_dist_tensor_spec], self.attrs)
+
+
+if __name__ == "__main__":
+    unittest.main()