Skip to content

Commit

Permalink
[CustomDevice] fix recompute (#53718)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ronny1996 committed May 11, 2023
1 parent 793f3b9 commit 2f56b6d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
6 changes: 4 additions & 2 deletions paddle/fluid/distributed/collective/process_group_custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,14 @@ void ProcessGroupCustom::BroadcastUniqueCustomID(
std::vector<phi::ccl::CCLRootId>& ccl_ids) { // NOLINT
if (rank_ == 0) {
for (size_t i = 0; i < ccl_ids.size(); i++) {
- auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(i);
+ auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" +
+ std::to_string(i);
store_->set(key, ccl_ids[i]);
}
} else {
for (size_t i = 0; i < ccl_ids.size(); i++) {
- auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(i);
+ auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" +
+ std::to_string(i);
ccl_ids[i] = store_->get(key);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ def forward(

cur_device = paddle.get_device()
assert (
- 'gpu:' in paddle.get_device() or 'xpu:' in paddle.get_device()
+ 'gpu:' in paddle.get_device()
+ or 'xpu:' in paddle.get_device()
+ or cur_device.split(':')[0]
+ in paddle.device.get_all_custom_device_type()
), "Recompute with RNG is not support current device: {}.".format(
cur_device
)
Expand Down

0 comments on commit 2f56b6d

Please sign in to comment.