Skip to content

Commit

Permalink
[PIR] Refine record stream for new_ir_interpreter gc (#58337)
Browse files Browse the repository at this point in the history
* refine

* fix
  • Loading branch information
zhangbo9674 authored Oct 25, 2023
1 parent 7adb2f0 commit baf8032
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions paddle/fluid/framework/new_executor/new_ir_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/pir/core/builtin_attribute.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif

PHI_DECLARE_bool(enable_new_ir_in_executor);
PHI_DECLARE_bool(enable_new_ir_in_executor_trace_run);
Expand Down Expand Up @@ -730,6 +737,29 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) {

gpuStream_t stream =
reinterpret_cast<const phi::GPUContext&>(instr->DeviceContext()).stream();
// TODO(lizhiyu): Only analyse the 'send_v2' for GPT pp strategy right now.
// To support all the operators for communicating in the future.
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (instr->Name() == "pd_op.send_v2") {
::pir::Operation* op = instr->Operation();
if (op->HasAttribute("use_calc_stream") &&
op->attribute<::pir::BoolAttribute>("use_calc_stream").data() ==
false) {
int ring_id = op->attribute<::pir::Int32Attribute>("ring_id").data();
if (FLAGS_dynamic_static_unified_comm) {
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
stream = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)))
->GetStream();
} else {
stream = platform::NCCLCommContext::Instance()
.Get(ring_id, instr->DeviceContext().GetPlace())
->stream();
}
}
}
#endif
auto TensorRecordStream = [&stream](phi::DenseTensor& tensor) {
auto allocation = tensor.Holder();
if (allocation == nullptr) {
Expand Down

0 comments on commit baf8032

Please sign in to comment.