[CINN][New Hardware Update] Replace nv target
* replace DefaultNVGPUTarget with DefaultDeviceTarget
DongBaiYue committed May 17, 2024
1 parent 2188b4a commit 01dd7f9
Showing 10 changed files with 15 additions and 15 deletions.
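The diff below swaps every call to cinn::common::DefaultNVGPUTarget() for cinn::common::DefaultDeviceTarget(), so the lowering passes stop hard-coding the NVIDIA GPU target. The new helper itself is not part of this diff; the following is only a minimal sketch of how such a dispatcher could look, assuming it falls back on the existing DefaultNVGPUTarget()/DefaultHostTarget() helpers and the CINN_WITH_CUDA build flag (the real implementation in this patch series may differ):

// Hypothetical sketch only; not taken from this commit.
// Assumes cinn/common/target.h and the CINN_WITH_CUDA build flag.
#include "paddle/cinn/common/target.h"

namespace cinn::common {

inline const Target& DefaultDeviceTarget() {
#ifdef CINN_WITH_CUDA
  static Target target = DefaultNVGPUTarget();  // CUDA build: keep the old NV behavior
#else
  static Target target = DefaultHostTarget();   // otherwise fall back to the host target
#endif
  return target;
}

}  // namespace cinn::common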
@@ -140,7 +140,7 @@ static int GetSharedSize(const cinn::dialect::ir::OpNode& op_node) {
     lane = inshape[idx];
   }
   // int max_num_threads =
-  //     cinn::common::DefaultNVGPUTarget().max_num_threads();
+  //     cinn::common::DefaultDeviceTarget().max_num_threads();
   int max_num_threads = 1000;
   if (lane > max_num_threads / 2) {
     return 0;
@@ -197,8 +197,8 @@ int GetSharedSize(::pir::Operation* op) {
     lane = inshape[idx];
   }
   // int max_num_threads =
-  //     cinn::common::DefaultNVGPUTarget().max_num_threads(); todo(phlrain): get
-  //     gpu max threads
+  //     cinn::common::DefaultDeviceTarget().max_num_threads();
+  //     todo(phlrain): get gpu max threads
   int max_num_threads = 2048;
   if (lane > max_num_threads / 2) {
     return 0;
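Both GetSharedSize hunks keep the thread budget hard-coded (1000 and 2048) until the commented-out target query can be restored, but the heuristic itself is visible: once the reduced lane exceeds half of the block's thread budget, no shared memory is requested. A small illustrative sketch of that rule, with a made-up helper name (not the real CINN code):

// Illustrative only: the "lane > max_num_threads / 2 => no shared memory" rule
// from the hunks above. SharedBytesForLane is a hypothetical helper.
#include <cstdint>

int64_t SharedBytesForLane(int64_t lane, int64_t max_num_threads,
                           int64_t element_bytes) {
  if (lane > max_num_threads / 2) {
    return 0;  // lane too wide for one block: skip shared-memory staging
  }
  return lane * element_bytes;  // one staged element per lane slot (assumption)
}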
@@ -56,7 +56,7 @@ void FusionOpAnalysis::PreCompileGroup() {
   }
   // Build and trigger compilaion cache.
   VLOG(4) << "Parallel Pre-Compile for Group with size: " << groups.size();
-  PirCompiler pir_compiler(cinn::common::DefaultNVGPUTarget());
+  PirCompiler pir_compiler(cinn::common::DefaultDeviceTarget());
   pir_compiler.Build(groups);
 }
 }  // namespace cinn::dialect::ir::details
@@ -61,7 +61,7 @@ std::vector<pir::Value> GetBlockOutsideInput(
 std::unordered_map<OpLoweringGroupPtr,
                    std::unordered_map<std::string, pir::Attribute>>
 CompileGroupAsOpAttribute(const std::vector<OpLoweringGroupPtr>& group_list) {
-  PirCompiler pir_compiler(cinn::common::DefaultNVGPUTarget());
+  PirCompiler pir_compiler(cinn::common::DefaultDeviceTarget());
   auto fn_ptr_res = pir_compiler.Build(group_list);

   std::unordered_map<OpLoweringGroupPtr,
@@ -85,7 +85,7 @@ std::unordered_map<std::string, ::pir::Attribute> GetJitKernelAttr(
     hlir::framework::pir::FusionInfo fusion_info(*group);
     return CompilationCache::Instance().GetKernelInfo(fusion_info);
   } else {
-    PirCompiler pir_compiler(cinn::common::DefaultNVGPUTarget());
+    PirCompiler pir_compiler(cinn::common::DefaultDeviceTarget());
     return pir_compiler.Build({group})[0];
   }
 };
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/op_lowering_util.cc
@@ -717,7 +717,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch,  // NOLINT
   // If the number of current device SM is smaller than the number of SM
   // required by Warp Reduce, the performance of Warp Reduce is better.
   // Otherwise, use Block Reduce.
-  auto max_num_threads = cinn::common::DefaultNVGPUTarget().max_num_threads();
+  auto max_num_threads = cinn::common::DefaultDeviceTarget().max_num_threads();
   int need_reduce_last_count = 1;
   for (int i = 0; i < inshape.size(); i++) {
     if (find(axes.begin(), axes.end(), i) == axes.end()) {
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/pir/op_lowering_util.cc
@@ -577,7 +577,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch,  // NOLINT
   // If the number of current device SM is smaller than the number of SM
   // required by Warp Reduce, the performance of Warp Reduce is better.
   // Otherwise, use Block Reduce.
-  auto max_num_threads = cinn::common::DefaultNVGPUTarget().max_num_threads();
+  auto max_num_threads = cinn::common::DefaultDeviceTarget().max_num_threads();
   int need_reduce_last_count = 1;
   for (int i = 0; i < inshape.size(); i++) {
     if (find(axes.begin(), axes.end(), i) == axes.end()) {
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/op/reduction.cc
@@ -263,7 +263,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
                       reduce_tmp_out.as_tensor_ref(),
                       tmp_out.as_tensor_ref(),
                       out.as_tensor_ref(),
-                      cinn::common::DefaultNVGPUTarget());
+                      cinn::common::DefaultDeviceTarget());

       std::vector<CINNValue> res{
           CINNValue(ir_sch.GetModule().GetExprs().at(0))};
@@ -279,7 +279,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
                       reduce_tmp_out.as_tensor_ref(),
                       tmp_out.as_tensor_ref(),
                       out.as_tensor_ref(),
-                      cinn::common::DefaultNVGPUTarget());
+                      cinn::common::DefaultDeviceTarget());

       std::vector<CINNValue> res{
           CINNValue(ir_sch.GetModule().GetExprs().at(0))};
6 changes: 3 additions & 3 deletions paddle/cinn/hlir/pe/reduction.cc
@@ -841,7 +841,7 @@ std::vector<ir::Tensor> TwoStepBlockReduceInternal(
   // If the number of current device SM is smaller than the number of SM
   // required by Warp Reduce, the performance of Warp Reduce is better.
   // Otherwise, use Block Reduce.
-  auto max_num_threads = cinn::common::DefaultNVGPUTarget().max_num_threads();
+  auto max_num_threads = cinn::common::DefaultDeviceTarget().max_num_threads();
   int need_reduce_last_count = 1;
   for (int i = 0; i < A->shape.size(); i++) {
     if (find(axes.begin(), axes.end(), i) == axes.end()) {
@@ -851,9 +851,9 @@
   int warp_reduce_need_sm_count =
       ceil((need_reduce_last_count * 32) /
            static_cast<float>(
-               cinn::common::DefaultNVGPUTarget().get_max_threads_per_sm()));
+               cinn::common::DefaultDeviceTarget().get_max_threads_per_sm()));
   // Set Num_max_threads to 32 is Warp Reduce
-  if (cinn::common::DefaultNVGPUTarget().get_multi_processor_count() <
+  if (cinn::common::DefaultDeviceTarget().get_multi_processor_count() <
       warp_reduce_need_sm_count) {
     max_num_threads = 32;
   }
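The pe/reduction.cc hunks contain the full warp-vs-block reduce decision: the pass estimates how many SMs a warp reduce would occupy (32 threads per reduced row) and keeps the warp-reduce path only when the device has fewer SMs than that estimate. A worked example of the same arithmetic with made-up device numbers (the real values come from the target query):

// Worked example of the decision above; device numbers are illustrative.
#include <cmath>
#include <cstdio>

int main() {
  int need_reduce_last_count = 512;  // product of the non-reduced dims (example)
  int max_threads_per_sm = 2048;     // e.g. get_max_threads_per_sm()
  int multi_processor_count = 4;     // e.g. get_multi_processor_count()
  int max_num_threads = 1024;        // e.g. max_num_threads()

  // Each reduced row keeps one warp (32 threads); how many SMs would that need?
  int warp_reduce_need_sm_count = static_cast<int>(
      std::ceil((need_reduce_last_count * 32) /
                static_cast<float>(max_threads_per_sm)));  // ceil(16384 / 2048) = 8

  // Fewer SMs than the warp-reduce footprint: cap the block at one warp.
  if (multi_processor_count < warp_reduce_need_sm_count) {
    max_num_threads = 32;
  }
  std::printf("warp_reduce_need_sm_count = %d, max_num_threads = %d\n",
              warp_reduce_need_sm_count, max_num_threads);
  return 0;
}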
2 changes: 1 addition & 1 deletion paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc
@@ -46,7 +46,7 @@ void TileTactic::Init(ScheduleContext* context) {
   };
   auto GetTreeReduceSize = [&](const ir::Expr& total_rb_extent) -> int64_t {
     const int64_t max_num_threads =
-        common::DefaultNVGPUTarget().max_num_threads();
+        cinn::common::DefaultDeviceTarget().max_num_threads();
     int64_t nums_thread_per_block = max_num_threads;
     if (total_rb_extent.is_constant()) {
       int64_t extent = static_cast<int64_t>(total_rb_extent.get_constant());
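The TileTactic hunk shows only the start of GetTreeReduceSize: the per-block thread count begins at the device thread budget and is then refined when the reduce extent is a compile-time constant. The rest of the lambda is not shown here; the following is a plausible sketch under a common "smallest power of two covering the extent, capped at the device budget" policy, which is an assumption and not necessarily what CINN does:

// Hypothetical sketch of a tree-reduce block-size choice; not the CINN code.
#include <algorithm>
#include <cstdint>

int64_t TreeReduceSize(int64_t extent, int64_t max_num_threads) {
  int64_t nums_thread_per_block = 1;
  // Grow to the smallest power of two that covers the reduce extent.
  while (nums_thread_per_block < extent &&
         nums_thread_per_block < max_num_threads) {
    nums_thread_per_block *= 2;
  }
  return std::min(nums_thread_per_block, max_num_threads);
}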
2 changes: 1 addition & 1 deletion paddle/cinn/optim/map_extern_call.cc
@@ -91,7 +91,7 @@ void DealWithIntrinsicsImpl(common::NVGPUArch, ir::Call *node, Expr *expr) {
   }

   std::string extern_func =
-      hlir::GetExternFuncName(cinn::common::DefaultNVGPUTarget(), dtype, name);
+      hlir::GetExternFuncName(cinn::common::DefaultDeviceTarget(), dtype, name);
   *expr = lang::CallExtern(extern_func, node->read_args, node->attrs);
 }
