Skip to content

Commit

Permalink
add UpdateWaitChain for process_group_custom (PaddlePaddle#51491)
Browse files Browse the repository at this point in the history
* add UpdateWaitChain for  process_group_custom

* add UpdateWaitChain for  process_group_custom
  • Loading branch information
ronny1996 committed Mar 13, 2023
1 parent 524eeb1 commit 9751bd0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
12 changes: 12 additions & 0 deletions paddle/fluid/distributed/collective/process_group_custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,18 @@ bool ProcessGroupCustom::CustomTask::Wait(std::chrono::milliseconds timeout) {
// Same as Wait
void ProcessGroupCustom::CustomTask::Synchronize() { Wait(kWaitTimeout); }

void ProcessGroupCustom::CustomTask::UpdateWaitChain(
const phi::DeviceContext& ctx) {
PADDLE_ENFORCE_NE(
std::find(places_.cbegin(), places_.cend(), ctx.GetPlace()),
places_.cend(),
phi::errors::NotFound("Cannot find the device context in this task."));
auto index = std::find(places_.cbegin(), places_.cend(), ctx.GetPlace()) -
places_.cbegin();
control_events_[index].Record(
reinterpret_cast<const phi::CustomContext&>(ctx));
}

ProcessGroupCustom::ProcessGroupCustom(
const std::shared_ptr<phi::distributed::Store>& store,
const std::string& device_type,
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/distributed/collective/process_group_custom.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ class ProcessGroupCustom : public ProcessGroupWithoutStream {
CommType CommType,
const std::vector<phi::DenseTensor>& inputs);

bool IsCompleted();
bool IsCompleted() override;
void SynchronizeStreams();
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
void Synchronize();
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override;
void Synchronize() override;
void UpdateWaitChain(const phi::DeviceContext& ctx) override;
void SetOutputs(std::vector<phi::DenseTensor>& outputs); // NOLINT
virtual ~CustomTask();

Expand Down

0 comments on commit 9751bd0

Please sign in to comment.