
Add trace hang function #59217

Merged · 13 commits · Nov 28, 2023
6 changes: 6 additions & 0 deletions paddle/fluid/distributed/collective/process_group.cc
@@ -28,6 +28,12 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid)
auto map = ProcessGroupMapFromGid::getInstance();
map->insert(gid_, this);
}
const char* global_rank = std::getenv("PADDLE_TRAINER_ID");
PADDLE_ENFORCE_NOT_NULL(
global_rank,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
global_rank_ = std::atoi(global_rank);
}

// TODO(sunyilun): methods below will be removed later
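The constructor now resolves a cluster-wide rank from the PADDLE_TRAINER_ID environment variable (typically exported by Paddle's distributed launcher), so trace records can identify processes globally rather than by group-local rank. A minimal standalone sketch of the same pattern, assuming only the C++ standard library (GetGlobalRankFromEnv is illustrative, not a Paddle API):

#include <cstdlib>
#include <stdexcept>
#include <string>

// Illustrative only: fail fast when the launcher did not export the
// trainer id, mirroring the PADDLE_ENFORCE_NOT_NULL check above.
int GetGlobalRankFromEnv() {
  const char* v = std::getenv("PADDLE_TRAINER_ID");
  if (v == nullptr) {
    throw std::runtime_error(
        "The environment variable 'PADDLE_TRAINER_ID' cannot be found.");
  }
  return std::stoi(v);  // e.g. "3" -> 3
}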
1 change: 1 addition & 0 deletions paddle/fluid/distributed/collective/process_group.h
@@ -490,6 +490,7 @@ class ProcessGroup {
}

protected:
int global_rank_{-1};
int rank_;
int size_;
int gid_;
77 changes: 69 additions & 8 deletions paddle/fluid/distributed/collective/process_group_nccl.cc
@@ -13,20 +13,19 @@
// limitations under the License.

#include "paddle/fluid/distributed/collective/process_group_nccl.h"

#include "paddle/fluid/distributed/collective/common.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/distributed/check/nccl_dynamic_check.h"
#include "paddle/phi/core/distributed/check/static_check.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/comm_task_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_task.h"
#include "paddle/phi/core/distributed/nccl_tools.h"
#include "paddle/phi/core/distributed/trace_utils.h"
#include "paddle/phi/core/distributed/utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
@@ -46,8 +45,6 @@ namespace paddle {
namespace distributed {

using phi::distributed::CheckSizeOnEachRank;
using phi::distributed::GetTraceEndKey;
using phi::distributed::GetTraceStartKey;
using phi::distributed::IsP2POP;
using phi::distributed::NCCLDTypeToString;
using phi::distributed::NCCLRedTypeToString;
@@ -119,6 +116,13 @@ ProcessGroupNCCL::ProcessGroupNCCL(
pg_timeout_(timeout) {
LOG(INFO) << "ProcessGroupNCCL pg_timeout_ " << pg_timeout_;
}
ProcessGroupNCCL::~ProcessGroupNCCL() {
LOG(INFO) << "ProcessGroupNCCL destruct ";
if (FLAGS_enable_async_trace) {
auto& comm_task_manager = phi::distributed::CommTaskManager::GetInstance();
comm_task_manager.Stop();
}
}
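The new destructor gives the trace machinery an orderly shutdown: when FLAGS_enable_async_trace is on, CommTaskManager runs background work that must be stopped before the process group's state goes away. A hedged sketch of the stop-before-destruction idiom (this Watchdog class is illustrative, not the CommTaskManager implementation):

#include <chrono>
#include <condition_variable>
#include <mutex>
#include <thread>

class Watchdog {
 public:
  Watchdog() : loop_([this] { Run(); }) {}
  ~Watchdog() { Stop(); }

  void Stop() {
    {
      std::lock_guard<std::mutex> lk(mu_);
      done_ = true;
    }
    cv_.notify_all();
    if (loop_.joinable()) loop_.join();  // join before members are destroyed
  }

 private:
  void Run() {
    std::unique_lock<std::mutex> lk(mu_);
    while (!done_) {
      cv_.wait_for(lk, std::chrono::seconds(1));  // check monitored tasks here
    }
  }

  std::mutex mu_;
  std::condition_variable cv_;
  bool done_ = false;
  std::thread loop_;  // started last, after the members it uses
};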

void ProcessGroupNCCL::GroupStart() {
NCCL_CHECK(phi::dynload::ncclGroupStart());
@@ -674,6 +678,7 @@ void ProcessGroupNCCL::GetStoreKey(const std::string& place_key,
} else {
*store_key = "nccl_ids/" + std::to_string(gid_) + "/" + place_key;
}
place_to_group_key_[place_key] = *store_key;
}

void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
@@ -711,6 +716,50 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
auto comm_ctx = std::make_unique<phi::GPUContext>(place);
comm_ctx->set_nccl_comm(nccl_comm_ctx->GetNcclComm());

if (FLAGS_enable_async_trace) {
// gather global ranks in current group
size_t gpu_global_rank_size = sizeof(int);
auto gpu_global_rank = phi::memory_utils::Alloc(
phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId()),
gpu_global_rank_size);

phi::memory_utils::Copy(phi::GPUPlace(),
gpu_global_rank->ptr(),
phi::CPUPlace(),
&global_rank_,
gpu_global_rank_size);

size_t gpu_global_ranks_size = num_ranks * sizeof(int);
auto gpu_global_ranks = phi::memory_utils::Alloc(
phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId()),
gpu_global_ranks_size);

NCCL_CHECK(phi::dynload::ncclAllGather(gpu_global_rank->ptr(),
gpu_global_ranks->ptr(),
1,
ncclInt,
nccl_comm_ctx->GetNcclComm(),
comm_ctx->stream()));

std::vector<int> global_ranks(num_ranks);
phi::memory_utils::Copy(phi::CPUPlace(),
global_ranks.data(),
phi::GPUPlace(),
gpu_global_ranks->ptr(),
gpu_global_ranks_size);

// store global_ranks in current group_key
static std::once_flag flag;
std::call_once(flag, [this]() {
phi::distributed::CommContextManager::GetInstance().SetStore(store_);
phi::distributed::CommTaskManager::GetInstance().SetTimeout(pg_timeout_);
});

std::string group_key = place_to_group_key_.at(place_key);
phi::distributed::CommContextManager::GetInstance().AddGroupRanks(
group_key, global_ranks);
}

auto* calc_ctx = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place));
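The FLAGS_enable_async_trace block above bootstraps trace metadata for the group: each process stages its global rank on the GPU, all-gathers one int per rank over the freshly created communicator, copies the result back to the host, and registers the ranks under the group key. A self-contained sketch of that exchange using raw CUDA and NCCL calls (AllGatherGlobalRanks is illustrative; error checking elided):

#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

std::vector<int> AllGatherGlobalRanks(int my_global_rank, int nranks,
                                      ncclComm_t comm, cudaStream_t stream) {
  int* send = nullptr;
  int* recv = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&send), sizeof(int));
  cudaMalloc(reinterpret_cast<void**>(&recv), nranks * sizeof(int));
  cudaMemcpyAsync(send, &my_global_rank, sizeof(int),
                  cudaMemcpyHostToDevice, stream);
  // Each rank contributes one int and receives nranks ints.
  ncclAllGather(send, recv, 1, ncclInt, comm, stream);
  std::vector<int> out(nranks);
  cudaMemcpyAsync(out.data(), recv, nranks * sizeof(int),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);  // result is valid only after this
  cudaFree(send);
  cudaFree(recv);
  return out;
}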

@@ -771,8 +820,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
if (!FLAGS_enable_async_trace) {
fn(nccl_comm_ctx, nccl_stream);
} else {
std::string group_key = place_to_group_key_.at(key);
auto comm_task =
std::make_shared<phi::distributed::NCCLCommTask>(place,
group_key,
rank_,
size_,
gid_,
@@ -837,16 +888,19 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Point2Point(
bool is_batch_p2p = s_group_call_counter > 0;
std::string key = "";

int p2p_nrank = 0;
if (is_batch_p2p) {
key = GetKeyFromPlace(place);
p2p_rank = rank_;
p2p_target_rank = peer;
p2p_nrank = GetSize();
} else {
int low_rank = rank_ < peer ? rank_ : peer;
int high_rank = rank_ < peer ? peer : rank_;
key = std::to_string(low_rank) + "->" + std::to_string(high_rank);
p2p_rank = rank_ < peer ? 0 : 1;
p2p_target_rank = 1 - p2p_rank;
p2p_nrank = 2;
}

platform::CUDADeviceGuard cuda_guard(place);
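Non-batched point-to-point traffic is traced as its own two-rank virtual group: the lower global rank maps to rank 0, the key names the pair, and p2p_nrank is fixed at 2. A hedged sketch of that mapping (P2PGroup and MakeP2PGroup are illustrative names, not Paddle APIs):

#include <string>

struct P2PGroup {
  int rank;         // this process's rank inside the virtual group
  int target_rank;  // the peer's rank inside the virtual group
  int nranks;       // always 2 for a non-batched p2p pair
  std::string key;  // e.g. "3->7"
};

// Illustrative only: mirrors the else branch above.
P2PGroup MakeP2PGroup(int my_rank, int peer) {
  const int low = my_rank < peer ? my_rank : peer;
  const int high = my_rank < peer ? peer : my_rank;
  const int p2p_rank = my_rank < peer ? 0 : 1;
  return {p2p_rank, 1 - p2p_rank, 2,
          std::to_string(low) + "->" + std::to_string(high)};
}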
@@ -857,6 +911,10 @@
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateNCCLEnvCache(place, key, store_key, comm_type, p2p_rank);
}
if (p2p_comm_seq_.find(key) == p2p_comm_seq_.end()) {
p2p_comm_seq_[key] = 0;
}
p2p_comm_seq_[key]++;

if (!use_calc_stream) {
SyncCalcStream(place, key);
@@ -869,18 +927,21 @@
auto nccl_comm = comm_ctx->nccl_comm();
auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();

std::string group_key = place_to_group_key_.at(key);
auto comm_task =
    std::make_shared<phi::distributed::NCCLCommTask>(place,
                                                     group_key,
                                                     p2p_rank,
                                                     p2p_nrank,
                                                     gid_,
                                                     p2p_comm_seq_[key],
                                                     tensor_tmp.numel(),
                                                     sync_op,
                                                     use_calc_stream,
                                                     nccl_comm,
                                                     nccl_stream,
                                                     comm_type,
                                                     pg_timeout_);

auto nccl_comm_ctx = this->GetCommContext(&store_key);

3 changes: 3 additions & 0 deletions paddle/fluid/distributed/collective/process_group_nccl.h
@@ -79,6 +79,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
int size,
int gid,
int64_t timeout = 30 * 60 * 1000);
~ProcessGroupNCCL();

std::string GetBackendName() const override { return "NCCL"; }

@@ -220,6 +221,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
place_to_comm_ctx_;

uint64_t comm_seq_{0};
std::unordered_map<std::string, uint64_t> p2p_comm_seq_;
std::unordered_map<std::string, std::string> place_to_group_key_;

// TODO(sunyilun): attrs below will be removed later
std::mutex mutex_;
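The two new maps carry the trace bookkeeping: p2p_comm_seq_ gives every peer pair an independent, monotonically increasing sequence number, so a hang report can pinpoint which send/recv in that pair's history stalled, and place_to_group_key_ lets Collective and Point2Point recover the group key recorded in GetStoreKey. A small sketch of the per-key counter idiom (NextSeq is illustrative):

#include <cstdint>
#include <string>
#include <unordered_map>

// Illustrative only: operator[] value-initializes a missing counter to 0,
// so this is equivalent to the explicit find() check in Point2Point above.
uint64_t NextSeq(std::unordered_map<std::string, uint64_t>* seqs,
                 const std::string& key) {
  return ++(*seqs)[key];
}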
20 changes: 20 additions & 0 deletions paddle/phi/core/distributed/comm_context_manager.cc
@@ -208,5 +208,25 @@ bool CommContextManager::Has(const std::string& unique_comm_key) const {
return id_to_comm_context_.find(unique_comm_key) != id_to_comm_context_.end();
}

void CommContextManager::SetGroupSize(const std::string& pg_key, int size) {
pg_key_size_[pg_key] = size;
}

void CommContextManager::AddGroupRanks(const std::string& pg_key,
std::vector<int> global_ranks) {
if (pg_key_ranks_.find(pg_key) == pg_key_ranks_.end()) {
pg_key_ranks_[pg_key] = global_ranks;
}
}

std::vector<int> CommContextManager::GetGroupRanks(
const std::string& pg_key) const {
PADDLE_ENFORCE_NE(
pg_key_ranks_.find(pg_key),
pg_key_ranks_.end(),
errors::NotFound("Cannot find pg_key %s in GroupRanks.", pg_key));
return pg_key_ranks_.at(pg_key);
}

} // namespace distributed
} // namespace phi
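A hedged usage sketch of the new accessors (the group key and sizes here are hypothetical; the manager is reached through its singleton as elsewhere in this PR):

#include <vector>

#include "paddle/phi/core/distributed/comm_context_manager.h"

void RegisterTraceGroup() {
  auto& mgr = phi::distributed::CommContextManager::GetInstance();
  mgr.SetGroupSize("nccl_ids/1/0", 4);              // hypothetical key
  mgr.AddGroupRanks("nccl_ids/1/0", {0, 1, 2, 3});  // first writer wins
  std::vector<int> ranks = mgr.GetGroupRanks("nccl_ids/1/0");
  (void)ranks;  // GetGroupRanks throws NotFound for a key never added
}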
12 changes: 12 additions & 0 deletions paddle/phi/core/distributed/comm_context_manager.h
@@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/phi/common/place.h"
#include "paddle/phi/core/distributed/comm_context.h"
@@ -64,6 +65,12 @@ class CommContextManager {

static void SetDeviceId(int dev_id);

void SetGroupSize(const std::string& pg_key, int size);

void AddGroupRanks(const std::string& pg_key, std::vector<int> global_ranks);

std::vector<int> GetGroupRanks(const std::string& pg_key) const;

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
static void CreateNCCLCommContext(const std::shared_ptr<Store>& store,
const std::string& unique_comm_key,
@@ -96,6 +103,11 @@
id_to_comm_context_;
std::shared_ptr<Store> store_;
static int device_id;

// process group key to global ranks map
std::unordered_map<std::string, std::vector<int>> pg_key_ranks_;
// process group key to group size map
std::unordered_map<std::string, int> pg_key_size_;
};

} // namespace distributed
26 changes: 25 additions & 1 deletion paddle/phi/core/distributed/comm_task.h
@@ -37,6 +37,7 @@ class CommTask {
public:
CommTask(const std::string& backend = "",
const phi::Place& place = phi::Place(),
const std::string& group_key = "",
int rank = -1,
int size = 0,
int gid = 0,
@@ -47,6 +48,7 @@
CommType comm_type = CommType::UNKNOWN)
: backend_(backend),
place_(place),
group_key_(group_key),
rank_(rank),
size_(size),
gid_(gid),
@@ -65,10 +67,11 @@
virtual ~CommTask() = default;

std::string UniqueKey() {
return "op:" + CommTypeToString(comm_type_) +
return "group_key:" + group_key_ + ",op:" + CommTypeToString(comm_type_) +
",gid:" + std::to_string(gid_) + ",seq:" + std::to_string(seq_);
}

std::string GroupKey() { return group_key_; }
std::string GetBackend() { return backend_; }
phi::Place GetPlace() { return place_; }
int GetGlobalRank() { return global_rank_; }
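Folding the group key into the trace key keeps entries from different process groups distinct even when gid and seq happen to coincide. A runnable sketch of the string UniqueKey() now builds (all values hypothetical):

#include <iostream>
#include <string>

int main() {
  std::string group_key = "nccl_ids/2/0";  // hypothetical
  std::string op = "AllReduce";
  int gid = 2;
  int seq = 7;
  std::cout << "group_key:" + group_key + ",op:" + op +
                   ",gid:" + std::to_string(gid) +
                   ",seq:" + std::to_string(seq)
            << std::endl;
  // prints: group_key:nccl_ids/2/0,op:AllReduce,gid:2,seq:7
}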
@@ -105,6 +108,12 @@ class CommTask {
return;
}

virtual void ClearRecord() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return;
}

virtual std::string GetCommErrors() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
@@ -125,6 +134,16 @@
phi::errors::Unimplemented("%s is not implemented.", __func__));
return false;
}
virtual void SetUpdated(bool updated) {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return;
}
virtual bool IsUpdated() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return false;
}
virtual void AbortComm() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
@@ -134,6 +153,7 @@
protected:
std::string backend_;
phi::Place place_;
std::string group_key_;
int global_rank_;
int rank_;
int size_;
@@ -145,7 +165,11 @@
CommType comm_type_;
bool start_trace_updated_{false};

// task status
bool started_ = false;
bool completed_ = false;
// task status changed
bool updated_ = true;
bool aborted_{false};
std::chrono::time_point<std::chrono::steady_clock> start_time_;
std::shared_ptr<Store> store_;