Commit 9e59973

Merge remote-tracking branch 'upstream/develop' into rm_fluid_gpu_device_function_h
huangjiyi committed Nov 21, 2022
2 parents 3762648 + 1ba308f commit 9e59973
Showing 124 changed files with 4,467 additions and 3,248 deletions.
1 change: 1 addition & 0 deletions AUTHORS.md
@@ -20,6 +20,7 @@ This is an incomplete list of authors of [Paddle](https://github.com/PaddlePaddl
| dragonwarrior | Long Wang |
| dyning | Yuning Du |
| emailweixu | Wei Xu |
| engineer1109 | Jia-Liang Wang |
| gangliao | Gang Liao |
| gongweibao | Wei-Bao Gong |
| guru4elephant | Daxiang Dong |
18 changes: 3 additions & 15 deletions paddle/fluid/distributed/collective/BKCLTools.h
@@ -77,23 +77,11 @@ class XPUEventManager {
device_index_));

platform::XPUDeviceGuard guard(device_index_);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_event_record(event_, ctx.stream()));
// TODO(zhangxiaoci) temporary solution: xpu::event seems buggy
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(ctx.stream()));
}

void Block(const XPUContext& ctx) const {
if (is_created_) {
auto device_index = ctx.GetPlace().device;
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"XPUContext's device %d does not match"
"Event's device %d",
device_index,
device_index_));
platform::XPUDeviceGuard guard(device_index_);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_wait_event(ctx.stream(), event_));
}
}
void Block(const XPUContext& ctx) const {}

private:
bool is_created_{false};
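
For context, a toy sketch of the record-then-block synchronization that XPUEventManager::Block previously performed via xpu_stream_wait_event (and that this change stubs out, with Record now falling back to a synchronous xpu_wait). std::promise/std::future stand in for the XPU event and stream APIs; this is illustrative only, not Paddle or XPU runtime code.

// Toy stand-in for the Record/Block idea: the producer "records" an event
// once its work is done, the consumer "blocks" on that event before reading.
#include <future>
#include <iostream>
#include <thread>

int main() {
  std::promise<void> event;                    // plays the role of an XPU event
  std::future<void> recorded = event.get_future();
  int shared_value = 0;

  std::thread producer([&] {
    shared_value = 42;                         // work on the producer "stream"
    event.set_value();                         // Record(): mark the work done
  });

  recorded.wait();                             // Block(): wait for the event
  std::cout << shared_value << std::endl;      // safe to read; prints 42
  producer.join();
  return 0;
}
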
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/collective/NCCLTools.h
@@ -47,7 +47,7 @@
namespace paddle {
namespace distributed {

#define NCCLCHECK(cmd) \
#define NCCL_CHECK(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
@@ -60,6 +60,7 @@ namespace distributed {
} while (0)

ncclRedOp_t ToNCCLRedType(ReduceOp reduction);

std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);

} // namespace distributed
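
The macro above (renamed from NCCLCHECK to NCCL_CHECK in this commit) follows the common do { ... } while (0) status-check pattern. Below is a self-contained sketch of that pattern using a made-up status type so it compiles without the NCCL headers; FakeResult, FakeErrorString, and the error-message wording are illustrative stand-ins, not the Paddle or NCCL definitions.

#include <cstdio>
#include <cstdlib>

// Stand-ins for ncclResult_t / ncclGetErrorString, purely for illustration.
enum FakeResult { kSuccess = 0, kGenericError = 1 };
inline const char* FakeErrorString(FakeResult r) {
  return r == kSuccess ? "success" : "generic failure";
}

// do/while(0) makes the macro behave like a single statement, so it stays
// safe inside unbraced if/else bodies.
#define STATUS_CHECK(cmd)                                      \
  do {                                                         \
    FakeResult r = (cmd);                                      \
    if (r != kSuccess) {                                       \
      std::fprintf(stderr, "Failed, error '%s' at %s:%d\n",    \
                   FakeErrorString(r), __FILE__, __LINE__);    \
      std::exit(EXIT_FAILURE);                                 \
    }                                                          \
  } while (0)

FakeResult DoWork(bool ok) { return ok ? kSuccess : kGenericError; }

int main() {
  STATUS_CHECK(DoWork(true));   // passes silently
  STATUS_CHECK(DoWork(false));  // prints the error and exits
  return 0;
}
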
16 changes: 5 additions & 11 deletions paddle/fluid/distributed/collective/ProcessGroup.cc
@@ -35,17 +35,6 @@ void ProcessGroup::Task::Synchronize() {}

void ProcessGroup::Task::UpdateWaitChain(const phi::DeviceContext& ctx) {}

ProcessGroup::ProcessGroup(int rank,
int size,
const platform::Place& place,
int gid)
: rank_(rank), size_(size), place_(place), gid_(gid) {
if (gid != IGNORE_ID) {
auto map = ProcessGroupMapFromGid::getInstance();
map->insert(gid_, this);
}
}

ProcessGroup::ProcessGroup(int rank, int size, int gid)
: rank_(rank), size_(size), gid_(gid) {
if (gid != IGNORE_ID) {
@@ -66,5 +55,10 @@ ProcessGroup::Task::Task(int rank,
bool sync_op)
: rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}

ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() {
static ProcessGroupIdMap instance;
return instance;
}

} // namespace distributed
} // namespace paddle
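
ProcessGroupIdMap::GetInstance() added above is a Meyers singleton: a function-local static map, created on first use and thread-safe to initialize under C++11. A minimal stand-alone sketch of the same pattern follows; Group and GroupIdMap are stand-ins, not the Paddle classes.

#include <iostream>
#include <memory>
#include <unordered_map>

struct Group {                      // stand-in for distributed::ProcessGroup
  explicit Group(int rank) : rank(rank) {}
  int rank;
};

class GroupIdMap : public std::unordered_map<int, std::shared_ptr<Group>> {
 public:
  static GroupIdMap& GetInstance() {
    static GroupIdMap instance;     // constructed once, on first call
    return instance;
  }
};

int main() {
  // Register a group under gid 0, then look it up from anywhere else.
  GroupIdMap::GetInstance().emplace(0, std::make_shared<Group>(/*rank=*/3));
  std::cout << GroupIdMap::GetInstance().at(0)->rank << std::endl;  // prints 3
  return 0;
}
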
116 changes: 48 additions & 68 deletions paddle/fluid/distributed/collective/ProcessGroup.h
@@ -82,21 +82,16 @@ class ProcessGroup {
};

public:
explicit ProcessGroup(int rank, int size, int gid);
ProcessGroup(int rank, int size, int gid);
virtual ~ProcessGroup() = default;
// TODO(dev): This constructor will be removed later.
explicit ProcessGroup(int rank,
int size,
const platform::Place& place,
int gid);

int GetRank() const { return rank_; }

int GetSize() const { return size_; }

virtual std::string GetBackendName() const = 0;

virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const {
virtual phi::DeviceContext* GetDeviceContext(const Place& place) const {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support get device_context.",
GetBackendName()));
@@ -150,6 +145,36 @@ class ProcessGroup {
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support reduce with sync_op flag.",
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support reduce_scatter with sync_op flag.",
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support scatter with sync_op flag.",
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
@@ -160,11 +185,12 @@ class ProcessGroup {
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor*,
int dst_rank,
int64_t offset,
int64_t numel,
bool sync_op) {
virtual std::shared_ptr<ProcessGroup::Task> Send(
const phi::DenseTensor& tensor,
int dst_rank,
int64_t offset,
int64_t numel,
bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support send with sync_op flag.",
GetBackendName()));
@@ -214,26 +240,12 @@ class ProcessGroup {
"ProcessGroup%s does not support send", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>&, int, bool) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send with sync_op flag",
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>&, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support recv", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>&, int, bool) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support recv with sync_op flag",
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
@@ -257,14 +269,6 @@ class ProcessGroup {
"ProcessGroup%s does not support AllToAll", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support alltoall", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
@@ -273,16 +277,6 @@ class ProcessGroup {
"ProcessGroup%s does not support reduce", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const ReduceOptions&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support reduce with sync_op flag",
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
@@ -291,33 +285,19 @@ class ProcessGroup {
"ProcessGroup%s does not support scatter", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ScatterOptions&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support scatter with sync_op flag",
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ReduceScatterOptions&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support reduce_scatter with sync_op flag",
GetBackendName()));
}

protected:
const int rank_;
const int size_;
const platform::Place place_;
const int gid_;
int rank_;
int size_;
int gid_;
};

class ProcessGroupIdMap
: public std::unordered_map<int, std::shared_ptr<ProcessGroup>> {
public:
static ProcessGroupIdMap& GetInstance();
};

// TODO(dev): The following method will be removed soon.
class ProcessGroupMapFromGid {
public:
bool has(int gid) {
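
The new virtuals added above (Reduce, ReduceScatter, Scatter with a sync_op flag) all follow the same convention: the base class throws an "Unimplemented" error that names the backend, and concrete backends override only what they support. A compact sketch of that convention using standard-library types only; BaseGroup and FakeNcclGroup are illustrative, not the Paddle classes.

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

class BaseGroup {
 public:
  virtual ~BaseGroup() = default;
  virtual std::string GetBackendName() const = 0;

  // Default: report that this backend does not implement the collective.
  virtual void Reduce(bool sync_op) {
    (void)sync_op;
    throw std::runtime_error("ProcessGroup" + GetBackendName() +
                             " does not support reduce with sync_op flag.");
  }
};

class FakeNcclGroup : public BaseGroup {
 public:
  std::string GetBackendName() const override { return "NCCL"; }
  void Reduce(bool sync_op) override {
    std::cout << "reduce issued, sync_op=" << std::boolalpha << sync_op
              << std::endl;
  }
};

int main() {
  std::unique_ptr<BaseGroup> group(new FakeNcclGroup());
  group->Reduce(true);  // dispatches to the NCCL-style override
  return 0;
}
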
19 changes: 17 additions & 2 deletions paddle/fluid/distributed/collective/ProcessGroupBKCL.cc
@@ -57,8 +57,14 @@ bool ProcessGroupBKCL::BKCLTask::Wait(std::chrono::milliseconds timeout) {

if (barrier_) {
// If we use the work to do barrier, we should block cpu

// TODO(zhangxiaoci) There is no such function that can sync entire device
// for xpu (for now), so all we can do is sync whatever stream that we know
// and hope for the best. Note that for correctness the communication stream
// needs to be in sync mode.
platform::XPUDeviceGuard guard(place_.GetDeviceId());
xpu_wait();
calc_ctx->Wait();
}
return true;
}
@@ -105,6 +111,7 @@ void ProcessGroupBKCL::BroadcastUniqueBKCLID(BKCLUniqueId* bkcl_id) {

void ProcessGroupBKCL::CreateBKCLEnvCache(const Place& place,
const std::string& place_key) {
platform::XPUDeviceGuard guard(place.GetDeviceId());
BKCLUniqueId bkcl_id;
if (rank_ == 0) {
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_get_unique_id(&bkcl_id));
@@ -275,12 +282,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Barrier(
return task;
}

const phi::DeviceContext& ProcessGroupBKCL::GetDeviceContext(
phi::DeviceContext* ProcessGroupBKCL::GetDeviceContext(
const Place& place) const {
return GetDeviceContext(place, /*use_calc_stream*/ false);
}

const phi::DeviceContext& ProcessGroupBKCL::GetDeviceContext(
phi::DeviceContext* ProcessGroupBKCL::GetDeviceContext(
const Place& place, bool use_calc_stream) const {
const std::string& key = GetKeyFromPlace(place);
if (use_calc_stream) {
@@ -524,5 +531,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
/*use_calc_stream*/ false);
}

std::shared_ptr<ProcessGroupBKCL> ProcessGroupBKCL::CreateProcessGroupBKCL(
const std::shared_ptr<Store>& store, int rank, int size, int gid) {
auto process_group =
std::make_shared<ProcessGroupBKCL>(store, rank, size, gid);
ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
return process_group;
}

} // namespace distributed
} // namespace paddle
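
CreateProcessGroupBKCL added above pairs construction with registration: the factory builds the group and records it in the gid-keyed singleton map so later code can find it by group id. A self-contained sketch of that create-then-register pattern follows; Store, FakeGroup, and Registry are stand-ins for illustration, not the Paddle types.

#include <cassert>
#include <memory>
#include <unordered_map>

struct Store {};  // stand-in for the distributed key-value store

class FakeGroup {
 public:
  FakeGroup(std::shared_ptr<Store> store, int rank, int size, int gid)
      : store_(std::move(store)), rank_(rank), size_(size), gid_(gid) {}

  // Factory in the style of CreateProcessGroupBKCL: construct, then register.
  static std::shared_ptr<FakeGroup> Create(const std::shared_ptr<Store>& store,
                                           int rank, int size, int gid) {
    auto group = std::make_shared<FakeGroup>(store, rank, size, gid);
    Registry().emplace(gid, group);
    return group;
  }

  static std::unordered_map<int, std::shared_ptr<FakeGroup>>& Registry() {
    static std::unordered_map<int, std::shared_ptr<FakeGroup>> map;
    return map;
  }

  int rank() const { return rank_; }

 private:
  std::shared_ptr<Store> store_;
  int rank_;
  int size_;
  int gid_;
};

int main() {
  auto group = FakeGroup::Create(std::make_shared<Store>(),
                                 /*rank=*/0, /*size=*/1, /*gid=*/7);
  assert(FakeGroup::Registry().count(7) == 1);  // registered under its gid
  return group->rank();
}
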
9 changes: 6 additions & 3 deletions paddle/fluid/distributed/collective/ProcessGroupBKCL.h
@@ -73,14 +73,17 @@ class ProcessGroupBKCL : public ProcessGroupStream {
int size,
int gid);

static std::shared_ptr<ProcessGroupBKCL> CreateProcessGroupBKCL(
const std::shared_ptr<Store>& store, int rank, int size, int gid);

std::string GetBackendName() const override {
return std::string(BKCL_BACKEND_NAME);
}

const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
phi::DeviceContext* GetDeviceContext(const Place& place) const override;

const phi::DeviceContext& GetDeviceContext(
const Place& place, bool use_calc_stream) const override;
phi::DeviceContext* GetDeviceContext(const Place& place,
bool use_calc_stream) const override;

std::shared_ptr<ProcessGroup::Task> AllReduce(
phi::DenseTensor* out_tensor,
17 changes: 15 additions & 2 deletions paddle/fluid/distributed/collective/ProcessGroupCustom.cc
@@ -299,7 +299,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
return task;
}

const phi::DeviceContext& ProcessGroupCustom::GetDeviceContext(
phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
const Place& place) const {
const std::string key = GetKeyFromPlace(place);
const auto& iter = places_to_ctx_.find(key);
@@ -308,7 +308,7 @@ const phi::DeviceContext& ProcessGroupCustom::GetDeviceContext(
places_to_ctx_.end(),
platform::errors::NotFound(
"Cannot find the device context in this process group."));
return *iter->second[0];
return iter->second[0].get();
}

phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
@@ -433,5 +433,18 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
CommType::BROADCAST);
}

std::shared_ptr<ProcessGroupCustom>
ProcessGroupCustom::CreateProcessGroupCustom(
const std::shared_ptr<Store>& store,
const std::string& device_type,
int rank,
int size,
int gid) {
auto process_group =
std::make_shared<ProcessGroupCustom>(store, device_type, rank, size, gid);
ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
return process_group;
}

} // namespace distributed
} // namespace paddle
9 changes: 8 additions & 1 deletion paddle/fluid/distributed/collective/ProcessGroupCustom.h
@@ -69,6 +69,13 @@ class ProcessGroupCustom : public ProcessGroup {
int size,
int gid);

static std::shared_ptr<ProcessGroupCustom> CreateProcessGroupCustom(
const std::shared_ptr<Store>& store,
const std::string& device_type,
int rank,
int size,
int gid);

std::string GetBackendName() const override { return "XCCL_" + device_type_; }

std::shared_ptr<ProcessGroup::Task> AllGather(
@@ -93,7 +100,7 @@ std::shared_ptr<ProcessGroup::Task> Barrier(
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;

const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
phi::DeviceContext* GetDeviceContext(const Place& place) const override;

phi::ccl::CCLComm CustomCCLComm(const Place& place) const;
