Skip to content

Commit

Permalink
[PIR] add distributed dialect.
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang committed Feb 29, 2024
1 parent ba71b83 commit f26e43c
Show file tree
Hide file tree
Showing 19 changed files with 729 additions and 14 deletions.
6 changes: 6 additions & 0 deletions paddle/fluid/pir/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@ if(WITH_MKLDNN)
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_onednn_op.cc)
endif()

file(GLOB_RECURSE dist_dialect_srcs
"${CMAKE_CURRENT_SOURCE_DIR}/distributed/ir/*.cc")

if(WITH_DISTRIBUTE)
set(op_dialect_srcs ${op_dialect_srcs} ${dist_dialect_srcs})
endif()
set(op_dialect_deps phi common pir type_info string_helper)

cc_library(
Expand Down
119 changes: 119 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/common/ddim.h"
#include "paddle/common/dim.h"
#include "paddle/common/hash_funcs.h"
#include "paddle/common/layout.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
#include "paddle/phi/common/reduce_type.h"
#include "paddle/pir/include/core/attribute_base.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/core/type.h"
#include "paddle/pir/include/core/type_base.h"
#include "paddle/pir/include/core/utils.h"
#include "paddle/utils/flat_hash_map.h"

namespace paddle {
namespace dialect {

struct ProcessMeshAttrStorage : public pir::AttributeStorage {
///
/// \brief Declare ParamKey according to parameter type.
///
using ParamKey = phi::distributed::ProcessMesh;

ProcessMeshAttrStorage(ParamKey&& process_mesh) // NOLINT
: process_mesh(std::move(process_mesh)) {}

///
/// \brief Each derived TypeStorage must define a Construct method, which
/// StorageManager uses to construct a derived TypeStorage.
///
static ProcessMeshAttrStorage* Construct(ParamKey&& key) {
return new ProcessMeshAttrStorage(std::move(key));
}

///
/// \brief Each derived TypeStorage must provide a HashValue method.
///
static std::size_t HashValue(const ParamKey& key) { return key.hash(); }

///
/// \brief Each derived TypeStorage needs to overload operator==.
///
bool operator==(const ParamKey& key) const { return process_mesh == key; }

ParamKey process_mesh;
};

struct TensorDistAttrStorage : public pir::AttributeStorage {
///
/// \brief Declare ParamKey according to parameter type.
///
using ParamKey = std::tuple<ProcessMeshAttribute,
std::vector<int64_t>,
flat_hash_map<int64_t, phi::ReduceType>>;

TensorDistAttrStorage(ParamKey&& param) // NOLINT
: process_mesh(std::get<0>(param)),
dims_mapping(std::move(std::get<1>(param))),
partial_status(std::move(std::get<2>(param))) {}
///
/// \brief Each derived TypeStorage must define a Construct method, which
/// StorageManager uses to construct a derived TypeStorage.
///
static TensorDistAttrStorage* Construct(ParamKey&& key) {
return new TensorDistAttrStorage(std::move(key));
}

///
/// \brief Each derived TypeStorage must provide a HashValue method.
///
static std::size_t HashValue(const ParamKey& key) {
auto mesh_hash = std::get<0>(key).hash();
auto dims_map_hash = std::hash<std::vector<int64_t>>()(std::get<1>(key));
std::string partial_status_str = "[";
for (auto& itr : std::get<2>(key)) {
partial_status_str +=
"Partial(dims:" + std::to_string(itr.first) + ", " +
phi::ReduceTypeStrings[static_cast<int>(itr.second)] + "), ";
}
partial_status_str += "]";
auto combine_hash = pir::detail::hash_combine(mesh_hash, dims_map_hash);
return pir::detail::hash_combine(
combine_hash, std::hash<std::string>()(partial_status_str));
}

///
/// \brief Each derived TypeStorage needs to overload operator==.
///
bool operator==(const ParamKey& key) const {
return process_mesh == std::get<0>(key) &&
dims_mapping == std::get<1>(key) &&
partial_status == std::get<2>(key);
}

ProcessMeshAttribute process_mesh;
std::vector<int64_t> dims_mapping;
// partial map would less or equal than to mesh.size.
// iterate operation (copy and comparison) would more frequency than random
// element access. <key: dim on mesh, value: reduce type>
flat_hash_map<int64_t, phi::ReduceType> partial_status;
};

} // namespace dialect
} // namespace paddle
73 changes: 73 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_attribute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
#include "paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h"
namespace paddle {
namespace dialect {
///
/// \brief ProcessMeshAttribute interface.
///
const phi::distributed::ProcessMesh& ProcessMeshAttribute::process_mesh()
const {
return storage()->process_mesh;
}
ProcessMeshAttribute ProcessMeshAttribute::get(
pir::IrContext* ctx, const phi::distributed::ProcessMesh& mesh) {
return Base::get(ctx, mesh);
}
ProcessMeshAttribute ProcessMeshAttribute::get(
pir::IrContext* ctx,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& process_ids,
const std::vector<std::string>& dim_names) {
return Base::get(ctx, shape, process_ids, dim_names);
}

///
/// \brief TensorDistAttribute interface.
///
ProcessMeshAttribute TensorDistAttribute::mesh_attr() const {
return storage()->process_mesh;
}
const std::vector<int64_t>& TensorDistAttribute::dims_mapping() const {
return storage()->dims_mapping;
}

std::set<int64_t> TensorDistAttribute::partial_dims() const {
auto& partial = partial_status();
std::set<int64_t> keys;
for (auto& kv : partial) {
keys.emplace(kv.first);
}
return keys;
}

const flat_hash_map<int64_t, phi::ReduceType>&
TensorDistAttribute::partial_status() const {
return storage()->partial_status;
}

TensorDistAttribute TensorDistAttribute::get(
pir::IrContext* ctx,
ProcessMeshAttribute mesh,
const std::vector<int64_t>& dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& partial_status) {
return Base::get(ctx, mesh, dims_mapping, partial_status);
}

} // namespace dialect
} // namespace paddle
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ProcessMeshAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::TensorDistAttribute)
101 changes: 101 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/common/reduce_type.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/pir/include/core/attribute.h"
#include "paddle/pir/include/core/builtin_attribute_storage.h"
#include "paddle/pir/include/core/utils.h"
#include "paddle/utils/flat_hash_map.h"

namespace paddle {
namespace dialect {
class ProcessMeshAttrStorage;
class TensorDistAttrStorage;

class ProcessMeshAttribute : public pir::AttrBase<ProcessMeshAttribute,
pir::Attribute,
ProcessMeshAttrStorage> {
public:
using Base::Base;
const phi::distributed::ProcessMesh& process_mesh() const;
const std::vector<int64_t>& shape() const { return process_mesh().shape(); }
const std::vector<int64_t>& process_ids() const {
return process_mesh().process_ids();
}
const std::vector<std::string>& dim_names() const {
return process_mesh().dim_names();
}
int64_t size() const { return process_mesh().size(); }
int64_t ndim() const { return process_mesh().ndim(); }
int64_t dim_size(int64_t dim) const { return process_mesh().dim_size(dim); }
int64_t dim_size(const std::string& dim_name) const {
return process_mesh().dim_size(dim_name);
}
bool empty() const { return process_mesh().empty(); }
bool contains(int64_t process_id) const {
return process_mesh().contains(process_id);
}
size_t hash() const { return process_mesh().hash(); }

std::string to_string() const { return process_mesh().to_string(); }

static ProcessMeshAttribute get(pir::IrContext* ctx,
const phi::distributed::ProcessMesh& mesh);
static ProcessMeshAttribute get(pir::IrContext* ctx,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& process_ids,
const std::vector<std::string>& dim_names);
};

class TensorDistAttribute : public pir::AttrBase<TensorDistAttribute,
pir::Attribute,
TensorDistAttrStorage> {
public:
using Base::Base;
ProcessMeshAttribute mesh_attr() const;
const phi::distributed::ProcessMesh& process_mesh() const {
return mesh_attr().process_mesh();
}
const std::vector<int64_t>& dims_mapping() const;

// return vector of mesh dims on which the this tensor is partial on
std::set<int64_t> partial_dims() const;

const flat_hash_map<int64_t, phi::ReduceType>& partial_status() const;

static TensorDistAttribute get(
pir::IrContext* ctx,
ProcessMeshAttribute mesh,
const std::vector<int64_t>& dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& partial_status);
static TensorDistAttribute get(
pir::IrContext* ctx,
const phi::distributed::ProcessMesh& mesh,
const std::vector<int64_t>& dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& partial_status) {
return get(ctx,
ProcessMeshAttribute::get(ctx, mesh),
dims_mapping,
partial_status);
}
};

} // namespace dialect
} // namespace paddle

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ProcessMeshAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::TensorDistAttribute)
52 changes: 52 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h"
#include "paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
#include "paddle/fluid/pir/dialect/distributed/ir/type_storage.h"

REGISTER_FILE_SYMBOLS(dist_dialect);
namespace paddle {
namespace dialect {

DistDialect::DistDialect(pir::IrContext *context)
: pir::Dialect(name(), context, pir::TypeId::get<DistDialect>()) {
initialize();
}

void DistDialect::initialize() {
RegisterAttributes<ProcessMeshAttribute, TensorDistAttribute>();
RegisterTypes<DistDenseTensorType>();
}

void DistDialect::PrintType(pir::Type type, std::ostream &os) const {}

void DistDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const {
if (auto process_mesh_attr = attr.dyn_cast<ProcessMeshAttribute>()) {
os << process_mesh_attr.process_mesh();
} else {
os << "error_attribute_type";
}
}

pir::OpPrintFn DistDialect::PrintOperation(pir::Operation *op) const {
return nullptr;
}

} // namespace dialect
} // namespace paddle

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DistDialect)
Loading

0 comments on commit f26e43c

Please sign in to comment.