[CINN][New Hardware Update] rename IRCudaSchedule #64318

Merged · 1 commit · May 15, 2024
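Summary: as part of the new-hardware update, this commit renames the CUDA-specific IR schedule helpers in paddle/cinn/hlir/pe to GPU-generic names and updates every call site in the HLIR op strategies:

IRCudaScheduleInjective           -> IRGpuScheduleInjective
IRCudaScheduleMatMul              -> IRGpuScheduleMatMul
IRCudaScheduleReduce              -> IRGpuScheduleReduce
IRCudaScheduleBlockReduce         -> IRGpuScheduleBlockReduce
IRCudaScheduleBlockReduceInternal -> IRGpuScheduleBlockReduceInternal
IRCudaScheduleBlockShuffleReduce  -> IRGpuScheduleBlockShuffleReduce
IRCudaTwoStepReduceSchedule       -> IRGpuTwoStepReduceSchedule

Only the function names and the matching VLOG strings change; the schedule logic itself is untouched. Helpers such as IRCudaSplitSchedule and IRScheduleInjectiveCPU are not renamed in this commit.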
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/op/contrib/gather_nd.cc
@@ -195,7 +195,7 @@ std::shared_ptr<framework::OpStrategy> StrategyForGatherNd(
},
[&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
[&](common::NVGPUArch) {
-pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
+pe::IRGpuScheduleInjective(ir_sch, output_shapes.front(), target);
},
});
}
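For context, every call site touched by this PR has the same shape: a per-architecture visitor whose NVGPUArch arm invokes the (now renamed) GPU schedule. Below is a minimal, self-contained sketch of that pattern; the std::variant/std::visit machinery and the type names are illustrative assumptions, not CINN's actual Arch API — only the lambda-per-arch structure mirrors the code above.

#include <iostream>
#include <variant>

// Illustrative stand-ins for CINN's architecture tags (assumed names).
struct X86Arch {};
struct ARMArch {};
struct NVGPUArch {};
using Arch = std::variant<X86Arch, ARMArch, NVGPUArch>;

// Classic overload-set helper for std::visit (C++17).
template <class... Ts>
struct Overloaded : Ts... {
  using Ts::operator()...;
};
template <class... Ts>
Overloaded(Ts...) -> Overloaded<Ts...>;

void ApplySchedule(const Arch &arch) {
  std::visit(Overloaded{
                 [&](X86Arch) { std::cout << "CPU injective schedule\n"; },
                 [&](ARMArch) { std::cout << "ARM: not implemented\n"; },
                 [&](NVGPUArch) {
                   // The real code calls, e.g.:
                   // pe::IRGpuScheduleInjective(ir_sch, output_shapes.front(), target);
                   std::cout << "GPU injective schedule\n";
                 },
             },
             arch);
}

int main() {
  ApplySchedule(NVGPUArch{});  // prints: GPU injective schedule
  return 0;
}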
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/op/contrib/repeat.cc
@@ -206,7 +206,7 @@ std::shared_ptr<framework::OpStrategy> StrategyForRepeat(
},
[&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
[&](common::NVGPUArch) {
-pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
+pe::IRGpuScheduleInjective(ir_sch, output_shapes.front(), target);
},
});
}
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/op/contrib/resize.cc
@@ -252,7 +252,7 @@ std::shared_ptr<framework::OpStrategy> StrategyForResize(
},
[&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
[&](common::NVGPUArch) {
-pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
+pe::IRGpuScheduleInjective(ir_sch, output_shapes.front(), target);
},
});
}
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/op/nn.cc
@@ -394,7 +394,7 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(
// gen, this code is to be removed.
if (conv_type != "forward") {
CHECK_EQ(vec_ast.size(), 1);
-pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
+pe::IRGpuScheduleInjective(ir_sch, output_shapes.front(), target);
std::vector<CINNValue> res{
CINNValue(ir_sch.GetModule().GetExprs().at(0))};
*ret = CINNValuePack{res};
58 changes: 29 additions & 29 deletions paddle/cinn/hlir/op/reduction.cc
@@ -227,8 +227,8 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
Expr out = vec_tensor[0];
Expr tmp_out = vec_tensor[1];

VLOG(3) << "Do IRCudaScheduleBlockReduceInternal Schedule!";
pe::IRCudaScheduleBlockReduceInternal(
VLOG(3) << "Do IRGpuScheduleBlockReduceInternal Schedule!";
pe::IRGpuScheduleBlockReduceInternal(
ir_sch, tmp_out.as_tensor_ref(), out.as_tensor_ref(), target);

std::vector<CINNValue> res{
@@ -240,12 +240,12 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
Expr tmp_out = vec_tensor[1];
Expr reduce_tmp_out = vec_tensor[2];

VLOG(3) << "Do IRCudaScheduleBlockReduce Schedule!";
pe::IRCudaScheduleBlockReduce(ir_sch,
reduce_tmp_out.as_tensor_ref(),
tmp_out.as_tensor_ref(),
out.as_tensor_ref(),
target);
VLOG(3) << "Do IRGpuScheduleBlockReduce Schedule!";
pe::IRGpuScheduleBlockReduce(ir_sch,
reduce_tmp_out.as_tensor_ref(),
tmp_out.as_tensor_ref(),
out.as_tensor_ref(),
target);

std::vector<CINNValue> res{
CINNValue(ir_sch.GetModule().GetExprs().at(0))};
@@ -257,13 +257,13 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
Expr reduce_tmp_out = vec_tensor[2];
Expr reshape = vec_tensor[3];

VLOG(3) << "Do IRCudaTwoStepReduceSchedule Schedule!";
pe::IRCudaTwoStepReduceSchedule(ir_sch,
reshape.as_tensor_ref(),
reduce_tmp_out.as_tensor_ref(),
tmp_out.as_tensor_ref(),
out.as_tensor_ref(),
cinn::common::DefaultNVGPUTarget());
VLOG(3) << "Do IRGpuTwoStepReduceSchedule Schedule!";
pe::IRGpuTwoStepReduceSchedule(ir_sch,
reshape.as_tensor_ref(),
reduce_tmp_out.as_tensor_ref(),
tmp_out.as_tensor_ref(),
out.as_tensor_ref(),
cinn::common::DefaultNVGPUTarget());

std::vector<CINNValue> res{
CINNValue(ir_sch.GetModule().GetExprs().at(0))};
@@ -274,12 +274,12 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
Expr tmp_out = vec_tensor[1];
Expr reduce_tmp_out = vec_tensor[2];

VLOG(3) << "Do IRCudaScheduleBlockReduce Schedule!";
pe::IRCudaScheduleBlockReduce(ir_sch,
reduce_tmp_out.as_tensor_ref(),
tmp_out.as_tensor_ref(),
out.as_tensor_ref(),
cinn::common::DefaultNVGPUTarget());
VLOG(3) << "Do IRGpuScheduleBlockReduce Schedule!";
pe::IRGpuScheduleBlockReduce(ir_sch,
reduce_tmp_out.as_tensor_ref(),
tmp_out.as_tensor_ref(),
out.as_tensor_ref(),
cinn::common::DefaultNVGPUTarget());

std::vector<CINNValue> res{
CINNValue(ir_sch.GetModule().GetExprs().at(0))};
@@ -292,8 +292,8 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
CHECK_EQ(vec_tensor.size(), 1);
Expr reduce_out = vec_tensor[0];

VLOG(3) << "Do IRCudaScheduleReduce Schedule!";
pe::IRCudaScheduleReduce(
VLOG(3) << "Do IRGpuScheduleReduce Schedule!";
pe::IRGpuScheduleReduce(
ir_sch,
reduce_out.as_tensor_ref(),
inputs[0]->shape.size() - reduce_axes.back() - 1,
@@ -308,12 +308,12 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
Expr reduce_internal = vec_tensor[1];
Expr reduce_reshape = vec_tensor[2];

VLOG(3) << "Do IRCudaScheduleBlockShuffleReduce Schedule!";
pe::IRCudaScheduleBlockShuffleReduce(ir_sch,
reduce_reshape.as_tensor_ref(),
reduce_internal.as_tensor_ref(),
reduce_out.as_tensor_ref(),
target);
VLOG(3) << "Do IRGpuScheduleBlockShuffleReduce Schedule!";
pe::IRGpuScheduleBlockShuffleReduce(ir_sch,
reduce_reshape.as_tensor_ref(),
reduce_internal.as_tensor_ref(),
reduce_out.as_tensor_ref(),
target);

std::vector<CINNValue> res{
CINNValue(ir_sch.GetModule().GetExprs().at(0))};
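Note on the hunks above: each arm of StrategyForReduce unpacks a different number of result tensors before choosing a schedule variant, and the conditions selecting each arm are elided from this diff. A hypothetical selector, inferred only from the tensor counts each branch consumes — not the real CINN control flow:

// Hedged sketch: which IRGpu* schedule each branch shape corresponds to.
switch (vec_tensor.size()) {
  case 1:  // reduce_out only
    // pe::IRGpuScheduleReduce(ir_sch, ..., target);
    break;
  case 2:  // out, tmp_out
    // pe::IRGpuScheduleBlockReduceInternal(ir_sch, ..., target);
    break;
  case 3:  // out, tmp_out, reduce_tmp_out (or reduce_out, reduce_internal, reduce_reshape)
    // pe::IRGpuScheduleBlockReduce(...) or pe::IRGpuScheduleBlockShuffleReduce(...);
    break;
  case 4:  // reshape as well
    // pe::IRGpuTwoStepReduceSchedule(ir_sch, ..., target);
    break;
}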
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/op/transform.cc
@@ -134,7 +134,7 @@ std::shared_ptr<OpStrategy> StrategyForMatMul(
<< "The input argument of matmul schedule is empty! Please check.\n";
CINNValuePack arg_pack = args[0];
std::vector<CINNValue> results =
-pe::IRCudaScheduleMatMul(arg_pack, output_shape, target);
+pe::IRGpuScheduleMatMul(arg_pack, output_shape, target);
*ret = CINNValuePack({results});
});

@@ -660,7 +660,7 @@ std::shared_ptr<OpStrategy> StrategyForMul(
<< "The input argument of matmul schedule is empty! Please check.\n";
CINNValuePack arg_pack = args[0];
std::vector<CINNValue> results =
-pe::IRCudaScheduleMatMul(arg_pack, output_shape, target);
+pe::IRGpuScheduleMatMul(arg_pack, output_shape, target);
*ret = CINNValuePack({results});
});

80 changes: 40 additions & 40 deletions paddle/cinn/hlir/pe/ir_schedule_pe.cc
@@ -157,10 +157,10 @@ void IRScheduleInjectiveCPU(ir::IRSchedule &ir_sch,  // NOLINT
<< ir_sch.GetModule().GetExprs().at(0);
}

-void IRCudaScheduleInjective(ir::IRSchedule &ir_sch,  // NOLINT
-                             const std::vector<int> &output_shape,
-                             const cinn::common::Target &target) {
-  VLOG(3) << "Begin IRCudaScheduleInjective ";
+void IRGpuScheduleInjective(ir::IRSchedule &ir_sch,  // NOLINT
+                            const std::vector<int> &output_shape,
+                            const cinn::common::Target &target) {
+  VLOG(3) << "Begin IRGpuScheduleInjective ";
auto all_blocks = ir_sch.GetAllBlocks();
auto loops = ir_sch.GetLoops(all_blocks[0]);
auto fused = ir_sch.Fuse(loops);
@@ -176,11 +176,11 @@ void IRCudaScheduleInjective(ir::IRSchedule &ir_sch,  // NOLINT
} else {
ir_sch.Bind(fused, "threadIdx.x");
}
VLOG(3) << "After IRCudaScheduleInjective, new ir is : "
VLOG(3) << "After IRGpuScheduleInjective, new ir is : "
<< ir_sch.GetModule().GetExprs().at(0);
}
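The hunks above show only the tail of IRGpuScheduleInjective's thread-binding logic. A hedged reconstruction of the decision being made, purely for orientation — the flattened extent, the thread budget, and the Split factors are assumptions inferred from the visible else-branch, not the exact CINN code:

// Assumed sketch: flatten the loop nest, then bind it to the GPU hierarchy.
// `num_threads` stands in for whatever per-block budget `target` provides.
int64_t loop_extent = 1;
for (int dim : output_shape) loop_extent *= dim;

const int num_threads = 1024;  // assumption, not CINN's actual bound
if (loop_extent > num_threads) {
  // Split the fused loop into (blocks, threads) and bind both levels.
  auto splited = ir_sch.Split(fused, {-1, num_threads});
  ir_sch.Bind(splited[0], "blockIdx.x");
  ir_sch.Bind(splited[1], "threadIdx.x");
} else {
  ir_sch.Bind(fused, "threadIdx.x");  // small outputs: one block suffices
}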

-std::vector<cinn::common::CINNValue> IRCudaScheduleMatMul(
+std::vector<cinn::common::CINNValue> IRGpuScheduleMatMul(
const cinn::common::CINNValuePack &arg_pack,
const std::vector<int> &output_shape,
const cinn::common::Target &target) {
@@ -359,11 +359,11 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch,  // NOLINT
<< ir_sch.GetModule().GetExprs().at(0);
}

-void IRCudaScheduleReduce(ir::IRSchedule &ir_sch,  // NOLINT
-                          ir::Tensor output,
-                          int last_dimension_num,
-                          const cinn::common::Target &target) {
-  VLOG(3) << "Before IRCudaScheduleReduce : "
+void IRGpuScheduleReduce(ir::IRSchedule &ir_sch,  // NOLINT
+                         ir::Tensor output,
+                         int last_dimension_num,
+                         const cinn::common::Target &target) {
+  VLOG(3) << "Before IRGpuScheduleReduce : "
<< ir_sch.GetModule().GetExprs().at(0);
int parallel_thread_num = 1;
auto &output_shape = output->shape;
@@ -411,15 +411,15 @@ void IRCudaScheduleReduce(ir::IRSchedule &ir_sch,  // NOLINT
auto loops = ir_sch.GetLoops(output->name);
ir_sch.Bind(loops[0], "blockIdx.x");
}
VLOG(3) << "After IRCudaScheduleReduce : "
VLOG(3) << "After IRGpuScheduleReduce : "
<< ir_sch.GetModule().GetExprs().at(0);
}

-void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch,  // NOLINT
-                                       ir::Tensor tmp_out,
-                                       ir::Tensor out,
-                                       const cinn::common::Target &target) {
-  VLOG(3) << "Before IRCudaScheduleBlockReduceInternal : "
+void IRGpuScheduleBlockReduceInternal(ir::IRSchedule &ir_sch,  // NOLINT
+                                      ir::Tensor tmp_out,
+                                      ir::Tensor out,
+                                      const cinn::common::Target &target) {
+  VLOG(3) << "Before IRGpuScheduleBlockReduceInternal : "
<< ir_sch.GetModule().GetExprs().at(0);
int fuse_times = ir_sch.GetLoops(tmp_out->name).size() - 2;
for (int idx = 0; idx < fuse_times; ++idx) {
@@ -509,16 +509,16 @@ void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch,  // NOLINT
}
}

VLOG(3) << "After IRCudaScheduleBlockReduceInternal : "
VLOG(3) << "After IRGpuScheduleBlockReduceInternal : "
<< ir_sch.GetModule().GetExprs().at(0);
}

-void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch,  // NOLINT
-                               ir::Tensor reduce_tmp_out,
-                               ir::Tensor tmp_out,
-                               ir::Tensor out,
-                               const cinn::common::Target &target) {
-  VLOG(3) << "Before IRCudaScheduleBlockReduce : "
+void IRGpuScheduleBlockReduce(ir::IRSchedule &ir_sch,  // NOLINT
+                              ir::Tensor reduce_tmp_out,
+                              ir::Tensor tmp_out,
+                              ir::Tensor out,
+                              const cinn::common::Target &target) {
+  VLOG(3) << "Before IRGpuScheduleBlockReduce : "
<< ir_sch.GetModule().GetExprs().at(0);
int tmp_put_shape_size_without_reduce = 0;
for (auto i : tmp_out->shape) {
@@ -659,16 +659,16 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch,  // NOLINT
}
}

VLOG(3) << "After IRCudaScheduleBlockReduce : "
VLOG(3) << "After IRGpuScheduleBlockReduce : "
<< ir_sch.GetModule().GetExprs().at(0);
}

-void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch,  // NOLINT
-                                      ir::Tensor reshape,
-                                      ir::Tensor internal,
-                                      ir::Tensor reduce_out,
-                                      const cinn::common::Target &target) {
-  VLOG(3) << "Before IRCudaScheduleBlockShuffleReduce : "
+void IRGpuScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch,  // NOLINT
+                                     ir::Tensor reshape,
+                                     ir::Tensor internal,
+                                     ir::Tensor reduce_out,
+                                     const cinn::common::Target &target) {
+  VLOG(3) << "Before IRGpuScheduleBlockShuffleReduce : "
<< ir_sch.GetModule().GetExprs().at(0);
// reshape compute inline
{
@@ -921,17 +921,17 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch,  // NOLINT
ir_sch.Unroll(r_loops.back());
}
}
VLOG(3) << "After IRCudaScheduleBlockShuffleReduce : "
VLOG(3) << "After IRGpuScheduleBlockShuffleReduce : "
<< ir_sch.GetModule().GetExprs().at(0);
}

-void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch,  // NOLINT
-                                 ir::Tensor reshape,
-                                 ir::Tensor internal,
-                                 ir::Tensor tmp_out,
-                                 ir::Tensor out,
-                                 const cinn::common::Target &target) {
-  VLOG(3) << "Before IRCudaTwoStepReduceSchedule : "
+void IRGpuTwoStepReduceSchedule(ir::IRSchedule &ir_sch,  // NOLINT
+                                ir::Tensor reshape,
+                                ir::Tensor internal,
+                                ir::Tensor tmp_out,
+                                ir::Tensor out,
+                                const cinn::common::Target &target) {
+  VLOG(3) << "Before IRGpuTwoStepReduceSchedule : "
<< ir_sch.GetModule().GetExprs().at(0);
// fuse axis
int fuse_times =
@@ -1038,7 +1038,7 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch,  // NOLINT
}
}
}
VLOG(3) << "After IRCudaTwoStepReduceSchedule : "
VLOG(3) << "After IRGpuTwoStepReduceSchedule : "
<< ir_sch.GetModule().GetExprs().at(0);
// ir_sch.SimpleComputeAt(ir_sch.GetBlock(tmp_out->name),
// ir_sch.GetLoops(out->name)[0]);
58 changes: 29 additions & 29 deletions paddle/cinn/hlir/pe/ir_schedule_pe.h
@@ -44,11 +44,11 @@ void IRScheduleInjectiveCPU(ir::IRSchedule &ir_sch,  // NOLINT
const cinn::common::Target &target,
bool vectorizable = true);

-void IRCudaScheduleInjective(ir::IRSchedule &ir_sch,  // NOLINT
-                             const std::vector<int> &output_shape,
-                             const cinn::common::Target &target);
+void IRGpuScheduleInjective(ir::IRSchedule &ir_sch,  // NOLINT
+                            const std::vector<int> &output_shape,
+                            const cinn::common::Target &target);

-std::vector<cinn::common::CINNValue> IRCudaScheduleMatMul(
+std::vector<cinn::common::CINNValue> IRGpuScheduleMatMul(
const cinn::common::CINNValuePack &arg_pack,
const std::vector<int> &output_shape,
const cinn::common::Target &target);
@@ -66,34 +66,34 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch,  // NOLINT
int axis,
const cinn::common::Target &target);

-void IRCudaScheduleReduce(ir::IRSchedule &ir_sch,  // NOLINT
-                          ir::Tensor out,
-                          int last_dimension_num,
-                          const cinn::common::Target &target);
-
-void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch,  // NOLINT
-                               ir::Tensor reduce_tmp_out,
-                               ir::Tensor tmp_out,
-                               ir::Tensor out,
-                               const cinn::common::Target &target);
-
-void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch,  // NOLINT
-                                       ir::Tensor tmp_out,
-                                       ir::Tensor out,
-                                       const cinn::common::Target &target);
-
-void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch,  // NOLINT
-                                      ir::Tensor reshape,
-                                      ir::Tensor internal,
-                                      ir::Tensor out,
-                                      const cinn::common::Target &target);
-
-void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch,  // NOLINT
-                                 ir::Tensor reshape,
-                                 ir::Tensor internal,
-                                 ir::Tensor tmp_out,
-                                 ir::Tensor out,
-                                 const cinn::common::Target &target);
+void IRGpuScheduleReduce(ir::IRSchedule &ir_sch,  // NOLINT
+                         ir::Tensor out,
+                         int last_dimension_num,
+                         const cinn::common::Target &target);
+
+void IRGpuScheduleBlockReduce(ir::IRSchedule &ir_sch,  // NOLINT
+                              ir::Tensor reduce_tmp_out,
+                              ir::Tensor tmp_out,
+                              ir::Tensor out,
+                              const cinn::common::Target &target);
+
+void IRGpuScheduleBlockReduceInternal(ir::IRSchedule &ir_sch,  // NOLINT
+                                      ir::Tensor tmp_out,
+                                      ir::Tensor out,
+                                      const cinn::common::Target &target);
+
+void IRGpuScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch,  // NOLINT
+                                     ir::Tensor reshape,
+                                     ir::Tensor internal,
+                                     ir::Tensor out,
+                                     const cinn::common::Target &target);
+
+void IRGpuTwoStepReduceSchedule(ir::IRSchedule &ir_sch,  // NOLINT
+                                ir::Tensor reshape,
+                                ir::Tensor internal,
+                                ir::Tensor tmp_out,
+                                ir::Tensor out,
+                                const cinn::common::Target &target);

void IRSoftmaxScheduleCPU(ir::IRSchedule &ir_sch, int axis = -1); // NOLINT

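After this header change, downstream code includes paddle/cinn/hlir/pe/ir_schedule_pe.h and calls the IRGpu* names. A hedged usage sketch — the wrapper below is illustrative only; real callers receive ir_sch, the tensors, and target from the op-strategy framework rather than constructing them:

#include "paddle/cinn/hlir/pe/ir_schedule_pe.h"

namespace cinn {
namespace hlir {

// Illustrative wrapper: shows the renamed entry point and its argument
// order as declared above. It does not build a valid schedule by itself.
void ScheduleReduceForGpu(ir::IRSchedule &ir_sch,  // NOLINT
                          ir::Tensor out,
                          int last_dimension_num,
                          const cinn::common::Target &target) {
  pe::IRGpuScheduleReduce(ir_sch, out, last_dimension_num, target);
}

}  // namespace hlir
}  // namespace cinn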
@@ -239,7 +239,7 @@ std::vector<ir::Expr> CalculateIndexCommonFactor(
// FLAGS_cinn_bucket_compile=1. However, some unit tests (e.g.
// test_resnet_cinn, test_instance_norm_op) are still running with the
// deprecated OpScheduler, and the ir::Expr will break this guarantee after
-// IRCudaScheduleBlockReduce function. So we have to relax the restriction
+// IRGpuScheduleBlockReduce function. So we have to relax the restriction
// here.
if (indexes[i].size() != indexes[0].size()) {
LOG(WARNING)