Skip to content

Commit

Permalink
[NewExe] sparse tensor support gc (PaddlePaddle#58074)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome authored and jiahy0825 committed Oct 16, 2023
1 parent 295aed6 commit 8742ddc
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,23 @@ void InterpreterCoreFastGarbageCollector::Add(Variable* var) {
for (auto& t : *tensor_arr) {
Add(t.MoveMemoryHolder());
}
} else if (var->IsType<phi::SparseCooTensor>()) {
Add(var->GetMutable<phi::SparseCooTensor>()
->mutable_indices()
->MoveMemoryHolder());
Add(var->GetMutable<phi::SparseCooTensor>()
->mutable_values()
->MoveMemoryHolder());
} else if (var->IsType<phi::SparseCsrTensor>()) {
Add(var->GetMutable<phi::SparseCsrTensor>()
->mutable_cols()
->MoveMemoryHolder());
Add(var->GetMutable<phi::SparseCsrTensor>()
->mutable_crows()
->MoveMemoryHolder());
Add(var->GetMutable<phi::SparseCsrTensor>()
->mutable_values()
->MoveMemoryHolder());
} else if (var->IsType<std::vector<Scope*>>()) {
// NOTE(@xiongkun03) conditional_op / while_op will create a STEP_SCOPE
// refer to executor.cc to see what old garbage collector does.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
#include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h"
#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
Expand Down Expand Up @@ -228,7 +229,9 @@ bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {

return type == proto::VarType::LOD_TENSOR ||
type == proto::VarType::SELECTED_ROWS ||
type == proto::VarType::LOD_TENSOR_ARRAY;
type == proto::VarType::LOD_TENSOR_ARRAY ||
type == proto::VarType::SPARSE_COO ||
type == proto::VarType::SPARSE_CSR;
}

std::unordered_map<const paddle::framework::OperatorBase*,
Expand Down Expand Up @@ -1002,6 +1005,33 @@ void BuildOpFuncList(const platform::Place& place,
if (var->IsType<phi::DenseTensor>()) {
garbages->emplace_back(
var->GetMutable<phi::DenseTensor>()->MoveMemoryHolder());
} else if (var->IsType<phi::SelectedRows>()) {
garbages->emplace_back(var->GetMutable<phi::SelectedRows>()
->mutable_value()
->MoveMemoryHolder());
var->GetMutable<phi::SelectedRows>()->mutable_rows()->clear();
} else if (var->IsType<LoDTensorArray>()) {
auto* tensor_arr = var->GetMutable<LoDTensorArray>();
for (auto& t : *tensor_arr) {
garbages->emplace_back(t.MoveMemoryHolder());
}
} else if (var->IsType<phi::SparseCooTensor>()) {
garbages->emplace_back(var->GetMutable<phi::SparseCooTensor>()
->mutable_indices()
->MoveMemoryHolder());
garbages->emplace_back(var->GetMutable<phi::SparseCooTensor>()
->mutable_values()
->MoveMemoryHolder());
} else if (var->IsType<phi::SparseCsrTensor>()) {
garbages->emplace_back(var->GetMutable<phi::SparseCsrTensor>()
->mutable_cols()
->MoveMemoryHolder());
garbages->emplace_back(var->GetMutable<phi::SparseCsrTensor>()
->mutable_crows()
->MoveMemoryHolder());
garbages->emplace_back(var->GetMutable<phi::SparseCsrTensor>()
->mutable_values()
->MoveMemoryHolder());
}
}
delete garbages; // free mem
Expand All @@ -1022,6 +1052,33 @@ void BuildOpFuncList(const platform::Place& place,
if (var->IsType<phi::DenseTensor>()) {
garbages->emplace_back(
var->GetMutable<phi::DenseTensor>()->MoveMemoryHolder());
} else if (var->IsType<phi::SelectedRows>()) {
garbages->emplace_back(var->GetMutable<phi::SelectedRows>()
->mutable_value()
->MoveMemoryHolder());
var->GetMutable<phi::SelectedRows>()->mutable_rows()->clear();
} else if (var->IsType<LoDTensorArray>()) {
auto* tensor_arr = var->GetMutable<LoDTensorArray>();
for (auto& t : *tensor_arr) {
garbages->emplace_back(t.MoveMemoryHolder());
}
} else if (var->IsType<phi::SparseCooTensor>()) {
garbages->emplace_back(var->GetMutable<phi::SparseCooTensor>()
->mutable_indices()
->MoveMemoryHolder());
garbages->emplace_back(var->GetMutable<phi::SparseCooTensor>()
->mutable_values()
->MoveMemoryHolder());
} else if (var->IsType<phi::SparseCsrTensor>()) {
garbages->emplace_back(var->GetMutable<phi::SparseCsrTensor>()
->mutable_cols()
->MoveMemoryHolder());
garbages->emplace_back(var->GetMutable<phi::SparseCsrTensor>()
->mutable_crows()
->MoveMemoryHolder());
garbages->emplace_back(var->GetMutable<phi::SparseCsrTensor>()
->mutable_values()
->MoveMemoryHolder());
}
}
delete garbages;
Expand Down
18 changes: 17 additions & 1 deletion paddle/fluid/framework/new_executor/new_ir_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
Expand Down Expand Up @@ -803,6 +805,18 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) {
for (auto& tensor : *tensor_arr) {
TensorRecordStream(tensor);
}
} else if (var->IsType<phi::SparseCooTensor>()) {
TensorRecordStream(
*(var->GetMutable<phi::SparseCooTensor>()->mutable_indices()));
TensorRecordStream(
*(var->GetMutable<phi::SparseCooTensor>()->mutable_values()));
} else if (var->IsType<phi::SparseCsrTensor>()) {
TensorRecordStream(
*(var->GetMutable<phi::SparseCsrTensor>()->mutable_cols()));
TensorRecordStream(
*(var->GetMutable<phi::SparseCsrTensor>()->mutable_crows()));
TensorRecordStream(
*(var->GetMutable<phi::SparseCsrTensor>()->mutable_values()));
} else if (var->IsType<std::vector<Scope*>>()) {
// do nothing
} else {
Expand Down Expand Up @@ -874,7 +888,9 @@ void NewIRInterpreter::CalculateLastLiveOps() {
paddle::framework::Variable* var = inner_scope->FindVar(
value_exe_info_->GetNameById(static_cast<int>(var_id)));
if (var->IsType<phi::DenseTensor>() || var->IsType<phi::SelectedRows>() ||
var->IsType<LoDTensorArray>()) {
var->IsType<LoDTensorArray>() ||
var->IsType<phi::SparseCooTensor>() ||
var->IsType<phi::SparseCsrTensor>()) {
last_live_ops_[var_id].insert(op_idx);
} else {
VLOG(4) << "not clear "
Expand Down
20 changes: 19 additions & 1 deletion paddle/fluid/framework/new_executor/program_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
Expand Down Expand Up @@ -740,7 +742,9 @@ void ProgramInterpreter::Convert(
paddle::framework::Variable* var = inner_scope->FindVar(
var_scope_.GetNameById(static_cast<int>(var_id)));
if (var->IsType<phi::DenseTensor>() || var->IsType<phi::SelectedRows>() ||
var->IsType<LoDTensorArray>()) {
var->IsType<LoDTensorArray>() ||
var->IsType<phi::SparseCooTensor>() ||
var->IsType<phi::SparseCsrTensor>()) {
last_live_ops_[var_id].insert(op_idx);
} else {
VLOG(4) << "not clear "
Expand Down Expand Up @@ -1305,6 +1309,18 @@ void ProgramInterpreter::RecordStreamForGC(const Instruction& instr) {
for (auto& tensor : *tensor_arr) {
TensorRecordStream(tensor);
}
} else if (var->IsType<phi::SparseCooTensor>()) {
TensorRecordStream(
*(var->GetMutable<phi::SparseCooTensor>()->mutable_indices()));
TensorRecordStream(
*(var->GetMutable<phi::SparseCooTensor>()->mutable_values()));
} else if (var->IsType<phi::SparseCsrTensor>()) {
TensorRecordStream(
*(var->GetMutable<phi::SparseCsrTensor>()->mutable_cols()));
TensorRecordStream(
*(var->GetMutable<phi::SparseCsrTensor>()->mutable_crows()));
TensorRecordStream(
*(var->GetMutable<phi::SparseCsrTensor>()->mutable_values()));
} else if (var->IsType<std::vector<Scope*>>()) {
// do nothing
} else {
Expand All @@ -1331,6 +1347,8 @@ void ProgramInterpreter::CheckGC(const Instruction& instr) {
// ignore all persistable var while GC
if (var_scope.VarDesc(static_cast<int>(var_id)) &&
var_scope.VarDesc(static_cast<int>(var_id))->Persistable()) {
VLOG(4) << "Skip persistable var: "
<< var_scope_.GetNameById(static_cast<int>(var_id));
continue;
}
if (is_ready) {
Expand Down

0 comments on commit 8742ddc

Please sign in to comment.