From 2f56b6dac106bc59c5154c327a43eaa035938c14 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Thu, 11 May 2023 18:45:00 +0800 Subject: [PATCH] [CustomDevice] fix recompute (#53718) --- paddle/fluid/distributed/collective/process_group_custom.cc | 6 ++++-- .../paddle/distributed/fleet/recompute/recompute_hybrid.py | 5 ++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc index b6c7063fd6fb7..1e4d1df337bdb 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.cc +++ b/paddle/fluid/distributed/collective/process_group_custom.cc @@ -125,12 +125,14 @@ void ProcessGroupCustom::BroadcastUniqueCustomID( std::vector& ccl_ids) { // NOLINT if (rank_ == 0) { for (size_t i = 0; i < ccl_ids.size(); i++) { - auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(i); + auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" + + std::to_string(i); store_->set(key, ccl_ids[i]); } } else { for (size_t i = 0; i < ccl_ids.size(); i++) { - auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(i); + auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" + + std::to_string(i); ccl_ids[i] = store_->get(key); } } diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py index c6acae878745b..a5689020eb009 100644 --- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -110,7 +110,10 @@ def forward( cur_device = paddle.get_device() assert ( - 'gpu:' in paddle.get_device() or 'xpu:' in paddle.get_device() + 'gpu:' in paddle.get_device() + or 'xpu:' in paddle.get_device() + or cur_device.split(':')[0] + in paddle.device.get_all_custom_device_type() ), "Recompute with RNG is not support current device: {}.".format( cur_device )