Merge pull request PaddlePaddle#34 from zyfncg/drr_pass
Merge develop
yuanlehome authored Sep 22, 2023
2 parents 43d5e48 + 85622b9 commit b731d39
Showing 254 changed files with 4,680 additions and 1,855 deletions.
2 changes: 1 addition & 1 deletion .clang-tidy
@@ -4,7 +4,7 @@ bugprone-argument-comment,
-bugprone-assert-side-effect,
-bugprone-bad-signal-to-kill-thread,
-bugprone-bool-pointer-implicit-conversion,
-bugprone-branch-clone,
bugprone-branch-clone,
bugprone-copy-constructor-init,
-bugprone-dangling-handle,
-bugprone-dynamic-static-initializers,
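For context, bugprone-branch-clone (newly enabled above by dropping its leading "-") reports if/else branches and switch cases whose bodies are identical. A minimal illustration of the pattern it flags (example code only, not from this repository):

```cpp
// clang-tidy would report bugprone-branch-clone here: both branches are
// identical, so the condition has no effect on behavior.
int Describe(int x) {
  if (x > 0) {
    return x * 2;
  } else {
    return x * 2;  // clone of the then-branch
  }
}
```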
2 changes: 1 addition & 1 deletion paddle/cinn/ast_gen_ius/tensor_group.cc
@@ -30,7 +30,7 @@ TensorGroup::TensorGroup(const std::vector<ir::Tensor>& tensors) {

for (auto& tensor : tensors) {
output_tensor_names_.insert(tensor->name);
std::set<ir::Expr> used_tensors = ir::CollectIRNodes(
std::set<ir::Expr> used_tensors = ir::ir_utils::CollectIRNodes(
tensor->body(), [](const Expr* x) { return x->as_tensor(); });
for (const Expr& x : used_tensors) {
const ir::Tensor to_dep = x.as_tensor_ref();
37 changes: 19 additions & 18 deletions paddle/cinn/auto_schedule/analysis/analyze_ir.cc
@@ -54,29 +54,30 @@ void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block) {
return;
}

ir::CollectIRNodesWithoutTensor(sche_block->body, [&](const Expr* x) {
const ir::Load* load_expr = x->As<ir::Load>();
if (load_expr != nullptr) {
const ir::Tensor t = load_expr->tensor.as_tensor_ref();
sche_block->read_buffers.emplace_back(
ir::BufferRange(t->buffer, IndicesToVars(load_expr->indices)));
return false;
}
const ir::Store* store_expr = x->As<ir::Store>();
if (store_expr != nullptr) {
const ir::Tensor t = store_expr->tensor.as_tensor_ref();
sche_block->write_buffers.emplace_back(
ir::BufferRange(t->buffer, IndicesToVars(store_expr->indices)));
return false;
}
return false;
});
ir::ir_utils::CollectIRNodesWithoutTensor(
sche_block->body, [&](const Expr* x) {
const ir::Load* load_expr = x->As<ir::Load>();
if (load_expr != nullptr) {
const ir::Tensor t = load_expr->tensor.as_tensor_ref();
sche_block->read_buffers.emplace_back(
ir::BufferRange(t->buffer, IndicesToVars(load_expr->indices)));
return false;
}
const ir::Store* store_expr = x->As<ir::Store>();
if (store_expr != nullptr) {
const ir::Tensor t = store_expr->tensor.as_tensor_ref();
sche_block->write_buffers.emplace_back(
ir::BufferRange(t->buffer, IndicesToVars(store_expr->indices)));
return false;
}
return false;
});
}

bool ContainsNodeType(ir::Expr expr,
const std::unordered_set<ir::IrNodeTy>& node_types) {
std::set<ir::Expr> collection =
ir::CollectIRNodesWithoutTensor(expr, [&](const Expr* x) {
ir::ir_utils::CollectIRNodesWithoutTensor(expr, [&](const Expr* x) {
return node_types.find(x->node_type()) != node_types.end();
});
return !collection.empty();
@@ -31,7 +31,7 @@ bool IsSpatialLoop(const ir::For* for_node) {
const auto& loop_var = for_node->loop_var;
// collect cases where the loop_var used in one of reduce axis in underneath
// ScheduleBlock
auto used_for_reduce_axis = ir::CollectIRNodesWithoutTensor(
auto used_for_reduce_axis = ir::ir_utils::CollectIRNodesWithoutTensor(
for_node->body, [&loop_var](const Expr* x) {
const auto* block_realize = x->As<ir::ScheduleBlockRealize>();
if (!block_realize) return false;
@@ -46,7 +46,7 @@ bool IsSpatialLoop(const ir::For* for_node) {
const ir::Expr& binding = block_realize->iter_values[i];
if (iter_var->is_reduce_axis ||
iter_var->name.substr(0, 6) == "reduce") {
auto used_exprs = ir::CollectIRNodesWithoutTensor(
auto used_exprs = ir::ir_utils::CollectIRNodesWithoutTensor(
binding, [&loop_var](const Expr* x) {
const ir::_Var_* var = x->As<ir::_Var_>();
if (var &&
@@ -49,7 +49,7 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr);

// Check the schedule block to be inlined is not a reduce tensor.
std::set<ir::Expr> find_store = ir::CollectIRNodesWithoutTensor(
std::set<ir::Expr> find_store = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Store>(); });
if (find_store.size() != 1UL) {
return false;
@@ -76,17 +76,19 @@ }
}

// Check this schedule block is the only writer of the tensor.
find_store = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::Store>() &&
(x->As<ir::Store>()->tensor).as_tensor_ref()->name == tensor->name;
});
find_store =
ir::ir_utils::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::Store>() &&
(x->As<ir::Store>()->tensor).as_tensor_ref()->name ==
tensor->name;
});
if (find_store.size() != 1UL) {
return false;
}
// Check there is no overlap between the buffers the schedule block reads and
// writes.
std::set<ir::Expr> find_load =
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) {
std::set<ir::Expr> find_load = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) {
return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor_expr;
});
if (!find_load.empty()) {
@@ -56,7 +56,7 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
return false;
};

auto find_target_exprs = ir::CollectIRNodesWithoutTensor(
auto find_target_exprs = ir::ir_utils::CollectIRNodesWithoutTensor(
schedule_block->body,
[&has_reduce_iter, &has_nonserial_loop](const Expr* x) {
return has_reduce_iter(x) || has_nonserial_loop(x);
7 changes: 3 additions & 4 deletions paddle/cinn/auto_schedule/search_space/search_state.cc
@@ -133,11 +133,10 @@ bool SearchStateEqual::operator()(const SearchState& lhs,
// compare exprs size firstly
if (lhs_exprs.size() != rhs_exprs.size()) return false;

// compare every expr one by one with ir::IrEqualVisitor
// compare every expr one by one with ir::ir_utils::IrEqualVisitor
for (int i = 0; i < lhs_exprs.size(); ++i) {
ir::IrEqualVisitor compartor(
/*allow_name_suffix_diff=*/true); // ignore suffix difference in name
if (!compartor.Compare(lhs_exprs[i], rhs_exprs[i])) return false;
if (!ir::ir_utils::IRCompare(lhs_exprs[i], rhs_exprs[i], true))
return false;
}
return true;
}
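For readers tracking this refactor, a minimal sketch of the change in call style, assuming ir::ir_utils::IRCompare(lhs, rhs, allow_name_suffix_diff) subsumes the old visitor; the helper below is illustrative and not part of the commit:

```cpp
// Illustrative only. Old call style shown in comments; new style as used in
// the hunk above. The boolean argument ignores name-suffix differences.
bool ExprsEquivalent(const ir::Expr& lhs, const ir::Expr& rhs) {
  // Before this commit:
  //   ir::IrEqualVisitor comparator(/*allow_name_suffix_diff=*/true);
  //   return comparator.Compare(lhs, rhs);
  return ir::ir_utils::IRCompare(lhs, rhs, /*allow_name_suffix_diff=*/true);
}
```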
4 changes: 2 additions & 2 deletions paddle/cinn/auto_schedule/search_space/search_state.h
@@ -70,8 +70,8 @@ struct SearchStateHash {
size_t operator()(const SearchState& s) const;
};

// SearchStateHash equal functor, use ir::IrEqualVisitor to compare their AST
// struct and fields
// SearchStateHash equal functor, use ir::ir_utils::IrEqualVisitor to compare
// their AST struct and fields
struct SearchStateEqual {
bool operator()(const SearchState& lhs, const SearchState& rhs) const;
};
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_c.cc
@@ -38,7 +38,7 @@ using cinn::common::float16;
const char *kCKeywordRestrict = "__restrict__";

void CodeGenC::Compile(const ir::Module &module, const Outputs &outputs) {
ir::IrVerify(Expr(module));
ir::ir_utils::IrVerify(Expr(module));

if (!outputs.c_header_name.empty()) {
auto source = Compile(module, OutputKind::CHeader);
4 changes: 2 additions & 2 deletions paddle/cinn/backends/codegen_cuda_dev.cc
@@ -56,7 +56,7 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module, bool for_nvrtc) {

void CodeGenCUDA_Dev::Compile(const ir::Module &module,
const Outputs &outputs) {
ir::IrVerify(Expr(module));
ir::ir_utils::IrVerify(Expr(module));

CodeGenC::inline_builtin_codes_ = false;
if (!outputs.c_header_name.empty()) {
@@ -90,7 +90,7 @@ std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(
temp_buffers.end());
// prepare temp buffer alias
std::vector<Expr> buffer_alias;
auto tensors = ir::CollectIRNodes(op->body, [&](const Expr *x) {
auto tensors = ir::ir_utils::CollectIRNodes(op->body, [&](const Expr *x) {
return x->as_tensor() && x->as_tensor()->buffer.defined() &&
temp_buffer_set.count(x->as_tensor()->buffer);
});
41 changes: 30 additions & 11 deletions paddle/cinn/backends/compiler.cc
@@ -45,7 +45,7 @@ using CompilationStatus = hlir::framework::CompilationStatus;
static constexpr int DebugLogMaxLen = 30000;

void CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
const ir::LoweredFunc& lowered_func, const int gidx) {
const ir::LoweredFunc& lowered_func, const int gidx, const int device_id) {
if (FLAGS_cinn_dump_group_lowered_func.empty() ||
lowered_func.get() == nullptr) {
return;
@@ -54,34 +54,42 @@ void CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
content << lowered_func;
Dump(FLAGS_cinn_dump_group_lowered_func,
gidx,
device_id,
"lowered_function.txt",
content.str());
}

void CompilationInfoDumper::DumpSourceCodeByGroupIndex(
const std::string& source_code, const int gidx) {
const std::string& source_code, const int gidx, const int device_id) {
if (FLAGS_cinn_dump_group_source_code.empty()) {
return;
}
Dump(FLAGS_cinn_dump_group_source_code, gidx, "source_code.cu", source_code);
Dump(FLAGS_cinn_dump_group_source_code,
gidx,
device_id,
"source_code.cu",
source_code);
}

void CompilationInfoDumper::DumpPtxCodeByGroupIndex(
const std::string& source_ptx, const int gidx) {
const std::string& source_ptx, const int gidx, const int device_id) {
if (FLAGS_cinn_dump_group_ptx.empty()) {
return;
}
Dump(FLAGS_cinn_dump_group_ptx, gidx, "source_ptx.ptx", source_ptx);
Dump(
FLAGS_cinn_dump_group_ptx, gidx, device_id, "source_ptx.ptx", source_ptx);
}

void CompilationInfoDumper::DumpInstructionByGroupIndex(
const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
const int gidx) {
const int gidx,
const int device_id) {
if (FLAGS_cinn_dump_group_instruction.empty() || instr.get() == nullptr) {
return;
}
Dump(FLAGS_cinn_dump_group_instruction,
gidx,
device_id,
"instruction.txt",
instr->DumpInstruction());
}
@@ -99,6 +107,7 @@ void CompilationInfoDumper::DumpLoweredFunc() {
}
Dump(FLAGS_cinn_dump_group_lowered_func,
idx,
device_id_,
"lowered_function.txt",
content.str());
}
@@ -115,7 +124,11 @@ void CompilationInfoDumper::DumpSourceCode() {
} else {
dump_str = "[No source code generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_source_code, idx, "source_code.cu", dump_str);
Dump(FLAGS_cinn_dump_group_source_code,
idx,
device_id_,
"source_code.cu",
dump_str);
}
}

@@ -130,7 +143,8 @@ void CompilationInfoDumper::DumpPtxCode() {
} else {
dump_str = "[No source ptxs generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_ptx, idx, "source_ptx.ptx", dump_str);
Dump(
FLAGS_cinn_dump_group_ptx, idx, device_id_, "source_ptx.ptx", dump_str);
}
}

@@ -145,16 +159,21 @@ void CompilationInfoDumper::DumpInstruction() {
} else {
dump_str = "[No instruction generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_instruction, idx, "instruction.txt", dump_str);
Dump(FLAGS_cinn_dump_group_instruction,
idx,
device_id_,
"instruction.txt",
dump_str);
}
}

void CompilationInfoDumper::Dump(const std::string& base_path,
const int idx,
const int device_id,
const std::string& file_name,
const std::string& content) {
auto dump_path =
utils::StringFormat("%s/fusion_group_%d", base_path.c_str(), idx);
auto dump_path = utils::StringFormat(
"%s/device_%d/fusion_group_%d", base_path.c_str(), device_id, idx);
if (!hlir::framework::MakeDirectory(
dump_path, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) {
LOG(WARNING) << "Failed to make directory: \"" << dump_path
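The net effect of threading device_id through Dump is an extra directory level in the dump layout. A self-contained sketch of the resulting path scheme (the format string comes from the hunk above; the base path and values are made up):

```cpp
#include <cstdio>
#include <string>

// Mirrors the "%s/device_%d/fusion_group_%d" pattern used by
// CompilationInfoDumper::Dump after this change.
std::string DumpDir(const std::string& base, int device_id, int group_idx) {
  char buf[512];
  std::snprintf(buf, sizeof(buf), "%s/device_%d/fusion_group_%d",
                base.c_str(), device_id, group_idx);
  return buf;  // e.g. "./cinn_dump/device_0/fusion_group_3"
}
```

So a lowered function for group 3 on device 0 would now land in ./cinn_dump/device_0/fusion_group_3/lowered_function.txt, where it previously went to ./cinn_dump/fusion_group_3/lowered_function.txt.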
19 changes: 13 additions & 6 deletions paddle/cinn/backends/compiler.h
@@ -43,23 +43,28 @@ namespace backends {
*/
class CompilationInfoDumper {
public:
explicit CompilationInfoDumper(const hlir::framework::CompilationResult& info)
: info_(info) {
explicit CompilationInfoDumper(const hlir::framework::CompilationResult& info,
const int device_id)
: info_(info), device_id_(device_id) {
DumpLoweredFunc();
DumpSourceCode();
DumpPtxCode();
DumpInstruction();
}

static void DumpLoweredFuncByGroupIndex(const ir::LoweredFunc& lowered_func,
const int gidx);
const int gidx,
const int device_id);
static void DumpSourceCodeByGroupIndex(const std::string& source_code,
const int gidx);
const int gidx,
const int device_id);
static void DumpPtxCodeByGroupIndex(const std::string& source_ptx,
const int gidx);
const int gidx,
const int device_id);
static void DumpInstructionByGroupIndex(
const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
const int gidx);
const int gidx,
const int device_id);

private:
void DumpLoweredFunc();
@@ -68,10 +73,12 @@ class CompilationInfoDumper {
void DumpInstruction();
static void Dump(const std::string& base_path,
const int idx,
const int device_id,
const std::string& file_name,
const std::string& content);

const hlir::framework::CompilationResult& info_;
const int device_id_;
};

class SourceCodePrint {
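A hedged usage sketch for the updated constructor; only the (info, device_id) signature comes from the header above, and the call site itself is hypothetical:

```cpp
#include "paddle/cinn/backends/compiler.h"

// Hypothetical caller: constructing the dumper immediately runs
// DumpLoweredFunc, DumpSourceCode, DumpPtxCode and DumpInstruction,
// now writing under per-device subdirectories.
void DumpCompilationForDevice(
    const cinn::hlir::framework::CompilationResult& info, int device_id) {
  cinn::backends::CompilationInfoDumper dumper(info, device_id);
}
```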
2 changes: 1 addition & 1 deletion paddle/cinn/backends/llvm/codegen_llvm.cc
@@ -790,7 +790,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) {
llvm::Value *CodeGenLLVM::Visit(const ir::_Module_ *op) {
{
Expr body_to_verify(&Reference(op));
ir::IrVerify(body_to_verify);
ir::ir_utils::IrVerify(body_to_verify);
}

for (auto &fn : op->functions) {
2 changes: 1 addition & 1 deletion paddle/cinn/backends/llvm/codegen_x86.cc
@@ -98,7 +98,7 @@ void CodeGenX86::CreateParallelLaunch(Expr body, int num_task) {
llvm::Function::PrivateLinkage,
"__parallel_lambda",
m_);
std::vector<std::string> vars = ir::CollectUndefinedVars(&body);
std::vector<std::string> vars = ir::ir_utils::CollectUndefinedVars(&body);
uint64_t nbytes;
auto* data = PackVars(vars, &nbytes);

4 changes: 2 additions & 2 deletions paddle/cinn/common/arithmatic.cc
@@ -126,7 +126,7 @@ GiNaC::ex ExprToGinacConverter::BuildHelper(ir::Expr expr) {

GiNaC::ex ExprToGinacConverter::operator()(Expr expr) {
// TODO(Superjomn) Replace this with common::IsPureMath(
auto complex_nodes = CollectIRNodes(expr, [](const Expr* n) {
auto complex_nodes = ir::ir_utils::CollectIRNodes(expr, [](const Expr* n) {
return n->As<Block>() || //
n->As<PolyFor>() || //
n->As<EQ>() || //
@@ -262,7 +262,7 @@ bool IsPureMath(Expr expr) {
IrNodeTy ::Minus,
});

auto complex_nodes = ir::CollectIRNodes(expr, [&](const Expr* n) {
auto complex_nodes = ir::ir_utils::CollectIRNodes(expr, [&](const Expr* n) {
return !valid_node_tys.count(n->node_type());
});
#ifdef CINN_DEBUG
2 changes: 1 addition & 1 deletion paddle/cinn/common/cas.cc
@@ -1868,7 +1868,7 @@ bool IsExprCasCompatible(Expr expr) {
return expr->As<Add>() || expr->As<Sub>() || expr->As<Mul>() ||
expr->As<Div>();
};
return ir::CollectIRNodes(expr, teller).empty();
return ir::ir_utils::CollectIRNodes(expr, teller).empty();
}

// Partially divide a by b. e.g. (2x+y)/2 => x + y/2