Skip to content

Commit

Permalink
opt op schedule (PaddlePaddle#436)
Browse files Browse the repository at this point in the history
* opt op schedule
  • Loading branch information
wenming2014 authored Sep 15, 2021
1 parent 40ca611 commit 5759f3e
Show file tree
Hide file tree
Showing 12 changed files with 163 additions and 32 deletions.
57 changes: 48 additions & 9 deletions cinn/backends/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,21 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Var_ *op) {
return result;
}

void CodeGenLLVM::Scalarize(const Expr &e, std::function<void(int i, llvm::Value *v)> flambda) {
if (const ir::Ramp *ramp = e.As<ir::Ramp>()) {
for (int i = 0; i < ramp->type().lanes(); ++i) {
Expr offset = ramp->base + (ramp->stride * i);
VLOG(3) << "offset: " << offset;
flambda(i, Visit(&offset));
}
} else {
llvm::Value *value = Visit(&e);
for (int i = 0; i < e->type().lanes(); ++i) {
flambda(i, b_->CreateExtractElement(value, i));
}
}
}

llvm::Value *CodeGenLLVM::Visit(const ir::Load *op) {
llvm::Value *array{nullptr};
bool is_alias{false};
Expand Down Expand Up @@ -731,14 +746,25 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Load *op) {
return load_inst;
} else { // vector load
Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1);
llvm::Value *buffer = Visit(&op->tensor);
if (dense_strided_ramp.defined()) {
CHECK(op->type().is_vector());

llvm::Value *buffer = Visit(&op->tensor);
return DenseVectorLoad(op);
} else {
LOG(FATAL) << "unsupported Ramp index " << op->index();
}
// scalarize load
Type type = op->type();
int alignment = type.bits() / 8;
llvm::Value *ret = llvm::UndefValue::get(CinnTypeToLLVMType(type, m_, true));
auto flambda = [&](int i, llvm::Value *index) {
auto *ptr = CreateBufferPtr(type.ElementOf(), buffer, index);
llvm::LoadInst *load_inst = b_->CreateAlignedLoad(ptr, llvm::Align(alignment), "load_vec");
ret = b_->CreateInsertElement(ret, load_inst, ll_const_int32(i));
if (auto *load_tensor = op->tensor.as_tensor()) {
AddTbaaMetadata(load_inst, load_tensor->name, op->index());
}
};
Scalarize(op->index(), flambda);
return ret;
}
}

Expand Down Expand Up @@ -784,13 +810,13 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Store *op) {
Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1);
auto ramp_expr = op->index();
auto *ramp = index.As<ir::Ramp>();
auto *buffer = Visit(&op->tensor);
auto *value = Visit(&op->value);

if (dense_strided_ramp.defined()) { // stride 1
int total_lanes = op->type().lanes();
int step = naive_vec_alignment_ / op->type().ElementOf().bits();

auto *buffer = Visit(&op->tensor);
auto *value = Visit(&op->value);

// fit the total_lanes in native_lanes(split into multiple native steps)
for (int offset = 0; offset < total_lanes; offset += total_lanes) {
int lanes = total_lanes;
Expand All @@ -805,9 +831,22 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Store *op) {
AddTbaaMetadata(inst, op->tensor.as_tensor()->name, base);
return inst;
}
} else {
LOG(FATAL) << "unsupported Ramp index " << ramp_expr;
}
// scalarize store
Type type = op->type();
int alignment = type.bits() / 8;
llvm::Value *ret = llvm::UndefValue::get(CinnTypeToLLVMType(type, m_, true));
auto flambda = [&](int i, llvm::Value *index) {
auto *ptr = CreateBufferPtr(type.ElementOf(), buffer, index);
llvm::StoreInst *store_inst =
b_->CreateAlignedStore(b_->CreateExtractElement(value, i), ptr, llvm::Align(alignment), "store_vec");
ret = b_->CreateInsertElement(ret, store_inst, ll_const_int32(i));
if (auto *store_tensor = op->tensor.as_tensor()) {
AddTbaaMetadata(store_inst, store_tensor->name, op->index());
}
};
Scalarize(op->index(), flambda);
return ret;
}
return nullptr;
}
Expand Down
2 changes: 2 additions & 0 deletions cinn/backends/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin<CodeGenLLVM> {

void InitTarget(const Target &target);

void Scalarize(const Expr &e, std::function<void(int i, llvm::Value *v)> flambda);

llvm::Module *m_;
llvm::IRBuilder<> *b_;
// Current function
Expand Down
28 changes: 13 additions & 15 deletions cinn/hlir/op/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,8 @@ std::shared_ptr<OpStrategy> StrategyForConv2dNCHWc(const framework::NodeAttr &at
// A is input: [N, C_in_outer, H, W, C_in_inner], B is filter: [C_out, C_in_group_outer, filter_h, filter_w,
// C_in_group_inner]
std::string key;
VLOG(3) << "input shape: " << utils::Join(tensor_a->shape, ", ");
VLOG(3) << "weight shape: " << utils::Join(tensor_b->shape, ", ");
VLOG(3) << "input[" << utils::Join(tensor_a->shape, ", ") << "], weight shape["
<< utils::Join(tensor_b->shape, ", ") << "]";
out = pe::Conv2d_NCHWc(tensor_a,
tensor_b,
padding[0],
Expand Down Expand Up @@ -1165,9 +1165,9 @@ std::shared_ptr<OpStrategy> StrategyForPool2d(const framework::NodeAttr &attrs,
CHECK(input_pad.as_tensor());
stages[input_pad.as_tensor_ref()]->ComputeInline();
}
CHECK(Out.as_tensor());
ir::Tensor temp_out = Out.as_tensor_ref();
if (target.arch == Target::Arch::NVGPU) {
CHECK(Out.as_tensor());
ir::Tensor temp_out = Out.as_tensor_ref();
pe::PoolScheduleGPU(stages, temp_out, target);
arg_pack[arg_pack.size() - 2] = Expr(temp_out);
}
Expand Down Expand Up @@ -1519,10 +1519,6 @@ std::shared_ptr<OpStrategy> StrategyForSoftmax(const framework::NodeAttr &attrs,
}
std::vector<ir::Tensor> out;
bool use_mkldnn = false;
#ifdef CINN_WITH_MKLDNN
use_mkldnn = true;
#endif
use_mkldnn = use_mkldnn && (target.arch == Target::Arch::X86);
if (use_mkldnn) {
out = pe::SoftmaxMKLDNN(A, new_axis, UniqName("Softmax_mkldnn_output"));
} else {
Expand All @@ -1544,21 +1540,23 @@ std::shared_ptr<OpStrategy> StrategyForSoftmax(const framework::NodeAttr &attrs,
CINNValuePack arg_pack = args[0];
CHECK_EQ(arg_pack.size(), 3UL) << "The input tensor's size of softmax schedule is " << arg_pack.size()
<< "and it should be equal to 3! Please check.";
Expr out1 = arg_pack[0];
Expr out2 = arg_pack[1];
poly::StageMap stages = arg_pack[2];
CHECK(out1.as_tensor());
CHECK(out2.as_tensor());
ir::Tensor tensor_a = out1.as_tensor_ref();
ir::Tensor tensor_b = out2.as_tensor_ref();
if (target.arch == Target::Arch::NVGPU) {
Expr out1 = arg_pack[0];
Expr out2 = arg_pack[1];
poly::StageMap stages = arg_pack[2];
CHECK(out1.as_tensor());
CHECK(out2.as_tensor());
ir::Tensor tensor_a = out1.as_tensor_ref();
ir::Tensor tensor_b = out2.as_tensor_ref();
if (tensor_a->shape.size() > 1) {
stages[tensor_a]->Split(1, 2);
stages[tensor_a]->Bind(0, "blockIdx.x");
stages[tensor_a]->Bind(1, "threadIdx.x");
int shape_size = tensor_a->shape.size();
stages[tensor_b]->ComputeAt(stages[tensor_a], shape_size);
}
} else if (target.arch == Target::Arch::X86) {
pe::SoftmaxScheduleCPU(stages, tensor_a, tensor_b, axis);
}
*ret = arg_pack;
});
Expand Down
4 changes: 3 additions & 1 deletion cinn/hlir/op/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,9 @@ std::shared_ptr<OpStrategy> StrategyForLayoutTransform(const framework::NodeAttr
for (auto shape : tensor_out->shape) {
out_shape.push_back(shape.as_int32());
}

if (target.arch == Target::Arch::X86) {
pe::ScheduleInjectiveCPUFuse(stages[tensor_out], out_shape, target);
}
*ret = arg_pack;
});

Expand Down
2 changes: 1 addition & 1 deletion cinn/hlir/pass/alterlayout_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ TEST(conv_sigmoid_conv, conv_sigmoid_conv) {
}

TEST(conv_mul_conv, conv_mul_conv) {
Placeholder A(Float(32), {1, 3, 224, 224}, "A");
Placeholder A(Float(32), {3, 3, 224, 224}, "A");
Placeholder B(Float(32), {64, 3, 7, 7}, "B");
Placeholder C(Float(32), {1, 64, 112, 112}, "C");
Placeholder D(Float(32), {64, 64, 7, 7}, "D");
Expand Down
53 changes: 53 additions & 0 deletions cinn/hlir/pe/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,39 @@ int GetBetterSplitFactor(int shape, int split_factor) {
return better_factor;
}

int GetVectorizeFactor(int shape, int split_factor) {
int better_factor = 1;
for (int i = split_factor; i > 1; i--) {
if (shape % i == 0) {
better_factor = i;
break;
}
}
return better_factor;
}

void ScheduleInjectiveCPUFuse(poly::Stage *stage, const std::vector<int> &output_shape, const common::Target &target) {
int dims = stage->n_out_dims();
int factor = GetBasicFactor(stage->tensor()->type(), target);
poly::Iterator fused = stage->axis(0);
if (dims >= 5) {
fused = stage->Fuse({0, 1, 2});
} else if (dims >= 3) {
fused = stage->Fuse({0, 1});
}
stage->Parallel(fused);
dims = stage->n_out_dims();
poly::Iterator lo;
poly::Iterator li;
int last_shape = stage->GetDimRange(dims - 1);
factor = GetVectorizeFactor(last_shape, factor);
std::tie(lo, li) = stage->Split(stage->axis(dims - 1), factor);
stage->Vectorize(li, factor);
if (dims == 1) {
stage->Parallel(0);
}
}

void ScheduleInjectiveCPU(poly::Stage *stage, const std::vector<int> &output_shape, const common::Target &target) {
int dims = stage->n_out_dims();
if (dims > 1) {
Expand Down Expand Up @@ -250,6 +283,25 @@ void MulScheduleCPU(poly::StageMap stages,
}
}

void SoftmaxScheduleCPU(poly::StageMap stage, const ir::Tensor &output, const ir::Tensor &temp, int axis) {
if (axis == -1) {
axis += output->shape.size();
}
poly::Iterator fused = stage[output]->axis(0);
stage[output]->Parallel(fused);
for (int i = 1; i < axis; i++) {
fused = stage[output]->Fuse(0, 1);
}
CHECK_GT(stage[output]->n_out_dims(), 1);
stage[temp]->ComputeAt(stage[output], 0);
}

void PoolScheduleCPU(poly::StageMap stages, const ir::Tensor &output, const common::Target &target) {
CHECK_GE(stages[output]->n_out_dims(), 2);
stages[output]->Fuse({0, 1});
stages[output]->Parallel(0);
}

void PoolScheduleGPU(poly::StageMap stages, ir::Tensor &output, const common::Target &target) {
CHECK_GE(stages[output]->axis_names().size(), 4);
stages[output]->Fuse({0, 1, 2, 3});
Expand All @@ -275,6 +327,7 @@ void GetConv2dFactors(std::unordered_map<std::string, int> *factors,
LoadSerialData(&params);
}
if (params.count(key)) {
VLOG(3) << "find saved param, key is: " << key;
CHECK(!params[key]["oc_bn"].empty());
CHECK(!params[key]["ic_bn"].empty());
CHECK(!params[key]["ow_bn"].empty());
Expand Down
4 changes: 4 additions & 0 deletions cinn/hlir/pe/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ int GetBetterSplitFactor(int shape, int split_factor);
int GetArrayPackingFactor(int shape, const Type &type, const common::Target &target);

void ScheduleInjectiveCPU(poly::Stage *stage, const std::vector<int> &output_shape, const common::Target &target);
void ScheduleInjectiveCPUFuse(poly::Stage *stage, const std::vector<int> &output_shape, const common::Target &target);

void MatmulScheduleCPU(poly::StageMap stage,
const ir::Tensor &output,
Expand All @@ -57,6 +58,8 @@ void MulScheduleCPU(poly::StageMap stage,
const ir::Tensor &input_tensor,
const common::Target &target);

void SoftmaxScheduleCPU(poly::StageMap stage, const ir::Tensor &output, const ir::Tensor &temp, int axis = -1);

void GetConv2dFactors(std::unordered_map<std::string, int> *factors,
int oc,
int ic,
Expand Down Expand Up @@ -86,6 +89,7 @@ void Conv2d_NCHWc_Schedule_CPU(poly::StageMap stages,
const std::string &key,
bool do_padding);

void PoolScheduleCPU(poly::StageMap stages, const ir::Tensor &output, const common::Target &target);
void PoolScheduleGPU(poly::StageMap stages, ir::Tensor &output, const common::Target &target);

void Conv2d_NCHWc_Schedule_CPU_Nofuse(poly::StageMap stages,
Expand Down
6 changes: 6 additions & 0 deletions cinn/optim/collect_undefined_vars.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ struct Mutator : public ir::IRMutator<> {
defined_vars.insert(var);
}

void ClearVar(const std::string& var) {
defined_vars.erase(var);
used_vars.erase(var);
}

void CollectVarUse(const std::string& var) {
used_vars.insert(var);
if (defined_vars.count(var) == 0) {
Expand All @@ -41,6 +46,7 @@ struct Mutator : public ir::IRMutator<> {
Visit(&node->min, &node->min);
Visit(&node->extent, &node->extent);
Visit(&node->body, &node->body);
ClearVar(op->loop_var->name);
}

void Visit(const ir::Load* op, Expr* expr) final {
Expand Down
9 changes: 9 additions & 0 deletions cinn/poly/stage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,15 @@ void Stage::Vectorize(const std::string &axis, int factor) {

void Stage::Vectorize(const Iterator &axis, int factor) { return Vectorize(axis.id, factor); }

void Stage::Parallel(const std::string &axis) {
auto dims = isl_get_dim_names(transformed_domain());
auto it = std::find(dims.begin(), dims.end(), axis);
CHECK(it != dims.end()) << "No dimension called " << axis;
Parallel(std::distance(dims.begin(), it));
}

void Stage::Parallel(const Iterator &axis) { return Parallel(axis.id); }

void Stage::Parallel(int level) {
CHECK_GE(level, 0);
AssertAxisIsNotLocked(level);
Expand Down
2 changes: 2 additions & 0 deletions cinn/poly/stage.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ class Stage : public Object {
* @param level
*/
void Parallel(int level);
void Parallel(const std::string& axis);
void Parallel(const Iterator& axis);

/**
* Unroll a for-loop.
Expand Down
19 changes: 17 additions & 2 deletions tests/benchmark/test_all_ops_default.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,21 @@ std::unordered_map<std::string, AttrType> attr_store_depthwise_conv2d = {{"paddi
{"dilation", dilation_depthwise_conv2d}};
TEST_DEFAULT1(depthwise_conv2d, depthwise_conv2d_nchw, type1, type7, attr_store_depthwise_conv2d)

// layout_transform
std::vector<std::vector<int>> shapes_layout_transform = {{512, 512, 3, 3}};
std::string src_layout = "OIHW";
std::string dst_layout = "OIHW16i16o";
std::unordered_map<std::string, AttrType> attr_store_layout_transform = {{"src_layout", src_layout},
{"dst_layout", dst_layout}};
TEST_DEFAULT1(layout_transform, layout_transform, type, type, attr_store_layout_transform)

std::vector<std::vector<int>> shapes_layout_transform1 = {{64, 3, 7, 7}};
std::string src_layout1 = "OIHW";
std::string dst_layout1 = "OIHW3i32o";
std::unordered_map<std::string, AttrType> attr_store_layout_transform1 = {{"src_layout", src_layout1},
{"dst_layout", dst_layout1}};
TEST_DEFAULT1(layout_transform, layout_transform1, type, type, attr_store_layout_transform1)

// pool2d
hlir::framework::NodeAttr attrs;
std::vector<int> kernel_size = {3, 3};
Expand All @@ -205,9 +220,9 @@ TEST_DEFAULT1(pool2d, pool2d1, type, type, attr_store_pool2d)

// softmax
std::vector<std::vector<int>> shapes_softmax = {{1024, 2048}};
TEST_DEFAULT(softmax, softmax, type, type6)
TEST_DEFAULT(softmax, softmax, type, type1)
std::vector<std::vector<int>> shapes_softmax1 = {{3, 1000}};
TEST_DEFAULT(softmax, softmax1, type, type6)
TEST_DEFAULT(softmax, softmax1, type, type1)

// sigmoid
std::vector<std::vector<int>> shapes_sigmoid = {{2, 672, 1, 1}};
Expand Down
9 changes: 5 additions & 4 deletions tools/paddle_benchmark/paddle_test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ def main():
predictor.zero_copy_run()

time1 = time.time()
for i in range(0, 500):
repeat = 10
for i in range(0, repeat):
predictor.zero_copy_run()
time2 = time.time()
total_inference_cost = (time2 - time1) * 1000 # total time cost(ms)
print("Average latency : {} ms".format(total_inference_cost / 500))
print("Average latency : {} ms".format(total_inference_cost / repeat))
output_names = predictor.get_output_names()
output_tensor = predictor.get_output_tensor(output_names[0])
output_data = output_tensor.copy_to_cpu()
Expand Down Expand Up @@ -66,8 +67,8 @@ def set_config(args):
config.switch_ir_optim(False)
#To test cpu backend, just uncomment the following 2 lines.
# config.switch_ir_optim(True)
#config.disable_gpu()
#config.enable_mkldnn()
# config.disable_gpu()
# config.enable_mkldnn()
return config


Expand Down

0 comments on commit 5759f3e

Please sign in to comment.