[Semi-Auto] Add reshape spmd rule (#55177)
* add reshape spmd rule

* add unit test for reshape spmd rule

* bug fix

* replace the print_info function with to_string

* fix typo

* bug fix

* add handling for "0" in target shape

* remove the part of computing size in dim_trans.cc
pkuzyc authored Aug 14, 2023
1 parent e075a0d commit a97b507
Showing 8 changed files with 988 additions and 3 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
@@ -125,14 +125,14 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim,
 TensorDistAttr ReplicatedOnMesh(const TensorDistAttr& src_dist_attr);
 
 // Check whether the given DistTensorSpec objects are valid. For each
-// DistTensorSpec, the rank of its dimsmapping must be equal to the rank of its
+// DistTensorSpec, the rank of its dims mapping must be equal to the rank of its
 // corresponding tensor shape. the parameter op_name is used for logging error
 // message.
 void VerifySpecs(const std::vector<DistTensorSpec>& specs,
                  const std::string& op_name);
 
-// Get dimsmapping for the given tensors. Return the pair of each
-// tensor's einsum notation and the corresponding dimsmapping.
+// Get dims mapping for the given tensors. Return the pair of each
+// tensor's einsum notation and the corresponding dims mapping.
 std::vector<std::pair<std::string, std::vector<int64_t>>>
 GetAxesDimsMappingPair(const std::vector<std::string>& tensor_axes,
                        const std::vector<DistTensorSpec>& specs);
355 changes: 355 additions & 0 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.cc
@@ -0,0 +1,355 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h"
#include <assert.h>
#include <cstdio>
#include <numeric>
#include <set>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/phi/core/enforce.h"

namespace paddle {
namespace distributed {
namespace auto_parallel {

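// Global registry of every DimTrans object created; CleanUp() walks this
// list to release them.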
static std::vector<DimTrans*> all_dim_trans;

DimTrans::DimTrans(Type type) : type_(type) {}

DimTrans::~DimTrans() {}

DimTrans::Type DimTrans::type() const { return type_; }

void DimTrans::set_type(Type type) { type_ = type; }

std::string DimTrans::to_string() { return std::string(""); }

InputDim::InputDim() : DimTrans(DimTrans::Type::INPUTDIM) {
input_dim_ = -1;
all_dim_trans.emplace_back(this);
}

InputDim::InputDim(int64_t dim) : DimTrans(DimTrans::Type::INPUTDIM) {
input_dim_ = dim;
all_dim_trans.emplace_back(this);
}

InputDim::~InputDim() {}

int64_t InputDim::input_dim() const { return input_dim_; }

void InputDim::set_input_dim(int64_t dim) { input_dim_ = dim; }

std::string InputDim::to_string() {
return ("InputDim(" + std::to_string(input_dim_) + ")");
}

Singleton::Singleton() : DimTrans(DimTrans::Type::SINGLETON) {
all_dim_trans.emplace_back(this);
}

std::string Singleton::to_string() { return "Singleton()"; }

Flatten::Flatten() : DimTrans(DimTrans::Type::FLATTEN) {
all_dim_trans.emplace_back(this);
}

Flatten::Flatten(const std::vector<DimTrans*>& dims)
: DimTrans(DimTrans::Type::FLATTEN) {
input_dims_ = dims;
all_dim_trans.emplace_back(this);
}

Flatten::~Flatten() {
input_dims_.assign(input_dims_.size(), nullptr);
std::vector<DimTrans*>().swap(input_dims_);
}

const std::vector<DimTrans*>& Flatten::inputs() const { return input_dims_; }

void Flatten::set_inputs(const std::vector<DimTrans*>& dims) {
input_dims_.assign(dims.begin(), dims.end());
}

std::string Flatten::to_string() {
std::string ret_str("Flatten(");
for (int64_t i = 0, n = input_dims_.size(); i < n; ++i) {
ret_str += input_dims_[i]->to_string();
if (i < n - 1) {
ret_str += ",";
}
}
return ret_str + ")";
}

Split::Split() : DimTrans(DimTrans::Type::SPLIT) {
input_dim_trans_ = nullptr;
all_dim_trans.emplace_back(this);
}

Split::Split(DimTrans* dim, const std::vector<int64_t>& shape, int64_t id)
: DimTrans(DimTrans::Type::SPLIT) {
input_dim_trans_ = dim;
split_id_ = id;
splitted_shape_.assign(shape.begin(), shape.end());
all_dim_trans.emplace_back(this);
}

Split::~Split() {
input_dim_trans_ = nullptr;
std::vector<int64_t>().swap(splitted_shape_);
}

DimTrans* Split::input() const { return input_dim_trans_; }

void Split::set_input(DimTrans* dim) { input_dim_trans_ = dim; }

int64_t Split::split_id() const { return split_id_; }

int64_t Split::local_splitted_shape_value() {
return splitted_shape_[split_id_];
}

std::string Split::to_string() {
std::string ret_str("Split(");
ret_str += input_dim_trans_->to_string() + ", (";
for (int64_t i = 0, n = splitted_shape_.size(); i < n; ++i) {
ret_str += std::to_string(splitted_shape_[i]);
if (i < n - 1) {
ret_str += ",";
}
}
return ret_str + "), " + std::to_string(split_id_) + ")";
}

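// Build the transformation that flattens the given input dimensions into one
// output dimension. An empty list yields a Singleton, and a single dimension
// is returned unchanged.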
DimTrans* make_flatten(const std::vector<DimTrans*>& dims) {
DimTrans* ptr = nullptr;
if (dims.size() == 0) {
ptr = new Singleton();
} else if (dims.size() == 1) {
ptr = dims[0];
} else {
ptr = new Flatten(dims);
}
return ptr;
}

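// Build the transformation for the `id`-th output dimension produced by
// splitting `dim` into `shape`. Size-1 entries are dropped from the splitted
// shape and `id` is remapped accordingly.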
DimTrans* make_split(DimTrans* dim,
const std::vector<int64_t>& shape,
int64_t id) {
assert(shape.size() > 0);
DimTrans* ptr = nullptr;
if (shape.size() == 1) {
assert(id == 0);
ptr = dim;
} else if (shape[id] == 1) {
ptr = new Singleton();
} else {
// the splitted shape with size-1 dimensions removed
std::vector<int64_t> new_shape;
// map from each index in shape to its index in new_shape
std::vector<int64_t> idx_map(shape.size(), -1);
for (int64_t i = 0, n = shape.size(); i < n; ++i) {
if (shape[i] != 1) {
idx_map[i] = new_shape.size();
new_shape.emplace_back(shape[i]);
}
}
ptr = new Split(dim, new_shape, idx_map[id]);
}
return ptr;
}

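// Release every DimTrans object recorded in the global registry.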
void CleanUp() {
for (int64_t i = 0, n = all_dim_trans.size(); i < n; i++) {
if (all_dim_trans[i]) {
delete all_dim_trans[i];
all_dim_trans[i] = nullptr;
}
}
std::vector<DimTrans*>().swap(all_dim_trans);
}

// Given a `dim_trans` of an output axis, get the input axis
// whose dim mapping should be propagated to it.
// If the returned input axis is nullptr, the output axis's
// dim mapping should be set to -1 (replicated). For an axis
// that is flattened from input axes, return the leftmost
// flattened input axis. For the split transformation,
// only the leftmost split axis in the output returns its input.
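// For example, given the transformations Flatten(InputDim(0), InputDim(1)),
// Split(InputDim(2), (8,6), 0) and Split(InputDim(2), (8,6), 1) (e.g. from
// reshaping [6, 12, 48] to [72, 8, 6]), and assuming input dims 0 and 2 are
// sharded, the three output axes map to InputDim(0), InputDim(2) and nullptr.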
DimTrans* GetDimTrans(DimTrans* dim_trans,
std::vector<std::vector<bool>>* shardable,
std::set<int64_t>* seen_dims,
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& mesh_shape,
const std::vector<int64_t>& input_dims_mapping,
const std::set<int64_t>& sharded_input_dims) {
DimTrans::Type type = dim_trans->type();
DimTrans* ret_dim_trans = nullptr;

if (type == DimTrans::Type::INPUTDIM) {
InputDim* inputdim = dynamic_cast<InputDim*>(dim_trans);
int64_t dim = inputdim->input_dim();
seen_dims->insert(dim);

if (sharded_input_dims.count(dim) > 0) {
ret_dim_trans = dim_trans;
}
} else if (type == DimTrans::Type::FLATTEN) {
Flatten* flatten = dynamic_cast<Flatten*>(dim_trans);
const std::vector<DimTrans*>& inputs = flatten->inputs();
int64_t nmesh = (*shardable)[0].size();
for (int64_t i = 1, n = inputs.size(); i < n; i++) {
DimTrans* input = inputs[i];
if (input->type() == DimTrans::Type::INPUTDIM) {
// an input dim flattened into a non-leading position cannot keep its shard
InputDim* inputdim = dynamic_cast<InputDim*>(input);
(*shardable)[inputdim->input_dim()].assign(nmesh, false);
}

GetDimTrans(input,
shardable,
seen_dims,
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims);
}

DimTrans* dim0 = inputs[0];
if (dim0->type() == DimTrans::Type::INPUTDIM) {
InputDim* inputdim = dynamic_cast<InputDim*>(dim0);
if (sharded_input_dims.count(inputdim->input_dim()) > 0) {
ret_dim_trans = dim0;
}
}
} else if (type == DimTrans::Type::SPLIT) {
Split* split = dynamic_cast<Split*>(dim_trans);
DimTrans* dim = GetDimTrans(split->input(),
shardable,
seen_dims,
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims);
int64_t ret_size = split->local_splitted_shape_value();

if (split->split_id() == 0) {
if (dim != nullptr) {
PADDLE_ENFORCE_EQ(dim->type(),
DimTrans::Type::INPUTDIM,
phi::errors::InvalidArgument(
"The returned dim_trans must be INPUTDIM."));
InputDim* inputdim = dynamic_cast<InputDim*>(dim);
int64_t nmesh = mesh_shape.size();
int64_t input_axis = inputdim->input_dim();

// Check whether this dimension can be sharded on each mesh
// dimension: its local size after the split must be divisible
// by the size of the mesh dimension it is sharded on.
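// e.g. a local size of 6 can be sharded on a mesh dimension
// of size 2 or 3, but not on one of size 4.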
for (int64_t imesh = 0; imesh < nmesh; imesh++) {
(*shardable)[input_axis][imesh] = (ret_size % mesh_shape[imesh] == 0);
}
}
ret_dim_trans = dim;
}
} else if (type == DimTrans::Type::SINGLETON) {
ret_dim_trans = nullptr;
}
return ret_dim_trans;
}

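// Collect into `seen_dims` every input dimension referenced by `dim_trans`,
// recursing through Flatten and Split transformations.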
void GetUsedInputDim(DimTrans* dim_trans, std::set<int64_t>* seen_dims) {
if (dim_trans->type() == DimTrans::Type::INPUTDIM) {
InputDim* input = dynamic_cast<InputDim*>(dim_trans);
seen_dims->insert(input->input_dim());
} else if (dim_trans->type() == DimTrans::Type::FLATTEN) {
Flatten* flatten = dynamic_cast<Flatten*>(dim_trans);
for (DimTrans* trans : flatten->inputs()) {
GetUsedInputDim(trans, seen_dims);
}
} else if (dim_trans->type() == DimTrans::Type::SPLIT) {
Split* split = dynamic_cast<Split*>(dim_trans);
GetUsedInputDim(split->input(), seen_dims);
} else {
return;
}
}

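// Infer the dims mappings from the dimension transformations: returns the
// (possibly adjusted) dims mapping of the input tensor and the dims mapping
// of the output tensor, in that order.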
std::vector<std::vector<int64_t>> InferFromDimTrans(
const DistTensorSpec& input_spec, const std::vector<DimTrans*>& dim_trans) {
const std::vector<int64_t>& input_shape = input_spec.shape();
const std::vector<int64_t>& input_dims_mapping = input_spec.dims_mapping();
const ProcessMesh& mesh = input_spec.dist_attr().process_mesh();
const std::vector<int64_t>& mesh_shape = mesh.shape();

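// record which input dimensions are sharded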
std::set<int64_t> sharded_input_dims;
for (int64_t i = 0, n = input_dims_mapping.size(); i < n; ++i) {
if (input_dims_mapping[i] > -1) {
sharded_input_dims.insert(i);
}
}
int64_t ndim = input_shape.size();
int64_t nmesh = mesh_shape.size();
std::vector<std::vector<bool>> shardable(ndim,
std::vector<bool>(nmesh, true));

std::set<int64_t> seen_input_dims;
for (DimTrans* trans : dim_trans) {
GetUsedInputDim(trans, &seen_input_dims);
}

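// input dimensions not referenced by any transformation cannot keep
// their shards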
for (int64_t idim = 0; idim < ndim; idim++) {
bool seen = seen_input_dims.count(idim);
if (!seen) {
shardable[idim].assign(nmesh, seen);
}
}

// get the map from sharded input dimensions to output dimensions.
std::vector<int64_t> dim_map_src2tgt(ndim, -1);
for (int64_t i = 0, n = dim_trans.size(); i < n; i++) {
DimTrans* dim = GetDimTrans(dim_trans[i],
&shardable,
&seen_input_dims,
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims);
if (dim != nullptr && dim->type() == DimTrans::Type::INPUTDIM) {
InputDim* inputdim = dynamic_cast<InputDim*>(dim);
dim_map_src2tgt[inputdim->input_dim()] = i;
}
}

std::vector<int64_t> out_dims_mapping(dim_trans.size(), -1);
std::vector<int64_t> new_input_dims_mapping(input_dims_mapping);

// set the output dims mapping from the corresponding input dimensions.
// if an input dimension is sharded on a mesh dimension that is
// unshardable after splitting, make it replicated instead.
for (int64_t i = 0; i < ndim; i++) {
int64_t mesh_dim = input_dims_mapping[i];
if (mesh_dim > -1 && shardable[i][mesh_dim] && dim_map_src2tgt[i] > -1) {
out_dims_mapping[dim_map_src2tgt[i]] = input_dims_mapping[i];
} else {
new_input_dims_mapping[i] = -1;
}
}

return {new_input_dims_mapping, out_dims_mapping};
}

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle