[MKLDNN] Use MKLDNNRun (#16772)
* Use MKLDNNRun

* Fix lint

* Run CI
ZhennanQin authored and pengzhao-intel committed Nov 29, 2019
1 parent 32a9baa commit 5fb2916
Showing 27 changed files with 150 additions and 185 deletions.
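
The file changes below all follow one pattern: operator-level ComputeExCPU functions stop calling their MKLDNN kernels directly and instead dispatch through MKLDNNRun, which (as the unary overload added in mkldnn_base.cc shows) reorders an input that is a view stored in MKLDNN layout back to the default layout before invoking the kernel. The following is a minimal sketch of that call-site pattern only; "SomeOp" and its helpers are hypothetical names, and the simplified body is an illustration rather than the exact code of any file in this commit.

// Sketch only -- "SomeOp", SupportMKLDNNSomeOp, and MKLDNNSomeOpForward are
// hypothetical names used for illustration, not part of this commit.
static void SomeOpComputeExCPU(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
                               const std::vector<NDArray>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<NDArray>& outputs) {
  if (SupportMKLDNNSomeOp(inputs[0])) {
    // Old style: call the MKLDNN kernel directly, leaving each kernel to
    // reorder view/MKLDNN-layout inputs on its own:
    //   MKLDNNSomeOpForward(attrs, ctx, inputs, req, outputs);
    // New style: route through MKLDNNRun, which handles that reordering once
    // and then invokes the kernel:
    MKLDNNRun(MKLDNNSomeOpForward, attrs, ctx, inputs, req, outputs);
    return;
  }
  FallBackCompute(SomeOpCompute<cpu>, attrs, ctx, inputs, req, outputs);
}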
2 changes: 1 addition & 1 deletion src/ndarray/ndarray.cc
@@ -684,7 +684,7 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
ptr_->Reorder2Default();

const mkldnn::memory *this_mem = GetMKLDNNData();
MKLDNNCopy(mem, this_mem);
MKLDNNMemoryCopy(mem, this_mem);
}

mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::desc &desc) {
4 changes: 2 additions & 2 deletions src/operator/leaky_relu.cc
@@ -95,7 +95,7 @@ static void LeakyReLUComputeExCPU(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), expected);
if (SupportMKLDNNLeakyRelu(param, inputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNLeakyReluForward(attrs, ctx, inputs[0], req[0], outputs[0]);
MKLDNNRun(MKLDNNLeakyReluForward, attrs, ctx, inputs[0], req[0], outputs[0]);
MKLDNN_OPCHECK_RUN(LeakyReLUCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
@@ -111,7 +111,7 @@ void LeakyReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs,
if (SupportMKLDNNLeakyRelu(param, inputs[0])) {
std::vector<NDArray> in_data{inputs[0], inputs[1]};
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNLeakyReluBackward(attrs, ctx, in_data, req[0], outputs[0]);
MKLDNNRun(MKLDNNLeakyReluBackward, attrs, ctx, in_data, req, outputs);
MKLDNN_OPCHECK_RUN(LeakyReLUGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
7 changes: 2 additions & 5 deletions src/operator/nn/activation.cc
@@ -102,7 +102,7 @@ static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
CHECK_EQ(outputs.size(), 1U);
if (SupportMKLDNNAct(param, inputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNActivationForward(attrs, ctx, inputs[0], req[0], outputs[0]);
MKLDNNRun(MKLDNNActivationForward, attrs, ctx, inputs[0], req[0], outputs[0]);
MKLDNN_OPCHECK_RUN(ActivationCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
@@ -118,10 +118,7 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), activation::GradNumInputs(param.act_type));
if (SupportMKLDNNAct(param, inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
// XXX: for y = relu(x), y is passed as "in_data" to Backward()
const bool relu = param.act_type == activation::kReLU;
MKLDNNActivationBackward(attrs, ctx, inputs.at(0), relu ? inputs.at(1) : inputs.at(2), req[0],
outputs[0]);
MKLDNNRun(MKLDNNActivationBackward, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
31 changes: 2 additions & 29 deletions src/operator/nn/batch_norm.cc
@@ -394,17 +394,11 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 5U);
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);

if (SupportMKLDNNBN(inputs[0], param)) {
std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());

if (inputs[0].dtype() == mshadow::kFloat32) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNBatchNormForward<float>(ctx, param, in_data, req, outputs, aux_states);
MKLDNNRun(MKLDNNBatchNormForward<float>, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
}
FallBackCompute(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
@@ -414,33 +408,12 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 8U);
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);

mxnet::TShape shape = inputs[0].shape();

if (SupportMKLDNNBN(inputs[0], param)) {
std::vector<NDArray> out_grad(1);
std::vector<NDArray> out_data(3);
std::vector<NDArray> in_data(3);
std::vector<NDArray> aux_states(2);
out_grad[0] = inputs[0];
out_data[batchnorm::kMean] = inputs[1];
out_data[batchnorm::kVar] = inputs[2];
in_data[batchnorm::kData] = inputs[3];
in_data[batchnorm::kGamma] = inputs[4];
in_data[batchnorm::kBeta] = inputs[5];
aux_states[batchnorm::kMovingMean] = inputs[6];
aux_states[batchnorm::kMovingVar] = inputs[7];
const std::vector<NDArray> &in_grad = outputs;

if (inputs[0].dtype() == mshadow::kFloat32) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNBatchNormBackward<float>(ctx, param, out_grad, in_data,
out_data, req, in_grad, aux_states);
MKLDNNRun(MKLDNNBatchNormBackward<float>, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
}
FallBackCompute(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
4 changes: 2 additions & 2 deletions src/operator/nn/concat.cc
@@ -270,7 +270,7 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
#if MXNET_USE_MKLDNN == 1
} else if (SupportMKLDNNConcat(inputs)) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNConcatForward(attrs, op_ctx, inputs, req, outputs);
MKLDNNRun(MKLDNNConcatForward, attrs, op_ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
} else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
FallBackCompute(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
@@ -288,7 +288,7 @@ static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNConcat(inputs)) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNConcatBackward(attrs, ctx, inputs, req, outputs);
MKLDNNRun(MKLDNNConcatBackward, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(ConcatGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
4 changes: 2 additions & 2 deletions src/operator/nn/fully_connected.cc
@@ -102,7 +102,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
if (SupportMKLDNNFC(inputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNFCForward(attrs, ctx, inputs, req, outputs);
MKLDNNRun(MKLDNNFCForward, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req,
outputs);
} else {
@@ -152,7 +152,7 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
bool mkldnn_fc_backward_enable = false;
if (mkldnn_fc_backward_enable && SupportMKLDNNFC(inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNFCBackward(attrs, ctx, inputs, req, outputs);
MKLDNNRun(MKLDNNFCBackward, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute<cpu>, attrs, ctx, inputs, req,
outputs);
return;
10 changes: 2 additions & 8 deletions src/operator/nn/lrn.cc
@@ -110,11 +110,10 @@ void LRNComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const LRNParam &param = nnvm::get<LRNParam>(attrs.parsed);
if (SupportMKLDNN(inputs[0])) {
// We only need to test one output array.
MKLDNN_OPCHECK_INIT(false, 1, inputs, outputs);
MKLDNNLRNForward(ctx, param, inputs[0], req[0], outputs[0]);
MKLDNNRun(MKLDNNLRNForward, attrs, ctx, inputs[0], req[0], outputs[0]);
MKLDNN_OPCHECK_RUN(LRNCompute<cpu>, attrs, ctx, inputs, req, outputs);
// Copy outputs[1] from opcheck reference as backward check needs it.
MKLDNN_OPCHECK_COPY_RESULT(outputs, std::vector<size_t>{1});
@@ -128,14 +127,9 @@ void LRNGradComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const LRNParam &param = nnvm::get<LRNParam>(attrs.parsed);
const NDArray &out_grad = inputs[0];
const NDArray &in_data = inputs[1];
const NDArray &in_grad = outputs[0];

if (SupportMKLDNN(inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNLRNBackward(ctx, param, out_grad, in_data, req[0], in_grad);
MKLDNNRun(MKLDNNLRNBackward, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(LRNGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
52 changes: 20 additions & 32 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -151,13 +151,8 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
MKLDNNActParam param_;
param_.alg = GetMKLDNNActAlgo(param);

NDArray in_buffer = in_data;
const NDArray& in_buffer = in_data;
MKLDNNStream *stream = MKLDNNStream::Get();

if (in_data.IsView() && in_data.IsMKLDNNData())
in_buffer = in_data.Reorder2Default();

auto input_mem = in_buffer.GetMKLDNNData();
MKLDNNActForward &fwd = GetActForward(param_, ctx, in_buffer, *input_mem);
auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), req, &in_buffer);
@@ -235,22 +230,18 @@ static inline MKLDNNActBackward &GetActBackward(const MKLDNNActParam &param,

// For backward relu activation, it's okay to pass "out_data" as "in_data" to this
// function, since the computation only involves non-zeros.
void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &out_grad, const NDArray &in_data,
const OpReqType &req, const NDArray &in_grad) {
if (req == kNullOp) {
void MKLDNNActivationBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
if (req[0] == kNullOp) {
return;
}

NDArray out_buffer = out_grad;
if (out_grad.IsView() && out_grad.IsMKLDNNData())
out_buffer = out_grad.Reorder2Default();

NDArray in_buffer = in_data;
if (in_data.IsView() && in_data.IsMKLDNNData())
in_buffer = in_data.Reorder2Default();

const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
// XXX: for y = relu(x), y is passed as "in_data" to Backward()
const bool relu = param.act_type == activation::kReLU;
const NDArray &out_buffer = inputs[0];
const NDArray &in_buffer = relu ? inputs[1] : inputs[2];
const NDArray &in_grad = outputs[0];
MKLDNNActParam param_;
param_.alg = GetMKLDNNActAlgo(param);
TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]);
@@ -264,7 +255,7 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
MKLDNNStream *stream = MKLDNNStream::Get();
mkldnn_output_t diff_src_memory =
CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req);
CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req[0]);
mkldnn_args_map_t args = {
{ MKLDNN_ARG_SRC, *input_mem },
{ MKLDNN_ARG_DIFF_DST, *diff_dst_memory },
@@ -278,19 +269,16 @@ void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs,
void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray>& inputs,
const OpReqType &req,
const NDArray &output) {
if (req == kNullOp) {
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
if (req[0] == kNullOp) {
return;
}
CHECK_GE(inputs.size(), 2U);
NDArray out_buffer = inputs[0];
if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
out_buffer = inputs[0].Reorder2Default();

NDArray in_buffer = inputs[1];
if (inputs[1].IsView() && inputs[1].IsMKLDNNData())
in_buffer = inputs[1].Reorder2Default();
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
const NDArray& out_buffer = inputs[0];
const NDArray& in_buffer = inputs[1];
const NDArray &output = outputs[0];

const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
MKLDNNActParam param_;
@@ -308,7 +296,7 @@ void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs,
GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
MKLDNNStream *stream = MKLDNNStream::Get();
mkldnn_output_t diff_src_memory =
CreateMKLDNNMem(output, bwd.bwd_pd.diff_src_desc(), req);
CreateMKLDNNMem(output, bwd.bwd_pd.diff_src_desc(), req[0]);
mkldnn_args_map_t args = {
{ MKLDNN_ARG_SRC, *input_mem },
{ MKLDNN_ARG_DIFF_DST, *diff_dst_memory },
15 changes: 14 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -446,7 +446,7 @@ enum OutDataOp {
};

typedef std::pair<OutDataOp, mkldnn::memory *> mkldnn_output_t;
void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem);
void MKLDNNMemoryCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem);

/*
* Here we want to get MKLDNN memory whose desc is exactly the same as
@@ -688,6 +688,19 @@ void MKLDNNRun(mxnet::FComputeEx fn,
const std::vector<mxnet::OpReqType> &req,
const std::vector<mxnet::NDArray> &outputs_);

using FComputeExUnary = std::function<void (const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const NDArray& input,
const OpReqType& req,
const NDArray& output)>;

void MKLDNNRun(FComputeExUnary fn,
const nnvm::NodeAttrs &attrs,
const mxnet::OpContext &ctx,
const mxnet::NDArray &inputs_,
const mxnet::OpReqType &req,
const mxnet::NDArray &outputs_);

} // namespace mxnet
#endif
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_
19 changes: 17 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_base.cc
@@ -85,7 +85,7 @@ mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::desc &md) {
}
}

void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem) {
void MKLDNNMemoryCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem) {
MKLDNNStream *stream = MKLDNNStream::Get();
mkldnn::memory::desc from_desc = mem.get_desc();
mkldnn::memory::desc this_desc = this_mem->get_desc();
@@ -227,7 +227,7 @@ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) {
auto mem = arr.GetMKLDNNData(res.second->get_desc());
if (mem == nullptr) {
auto tmp_memory = TmpMemMgr::Get()->Alloc(target_pd);
MKLDNNCopy(*res_memory, tmp_memory);
MKLDNNMemoryCopy(*res_memory, tmp_memory);
res_memory = tmp_memory;
mem = arr.GetMKLDNNData();
}
@@ -606,6 +606,21 @@ void MKLDNNRun(mxnet::FComputeEx fn,
}
}

void MKLDNNRun(FComputeExUnary fn,
const nnvm::NodeAttrs &attrs,
const mxnet::OpContext &ctx,
const mxnet::NDArray &input,
const mxnet::OpReqType &req,
const mxnet::NDArray &output) {
auto mkldnn_input = input;
if (input.IsView() && input.IsMKLDNNData()) {
mkldnn_input = input.Reorder2Default();
fn(attrs, ctx, mkldnn_input, req, output);
} else {
fn(attrs, ctx, input, req, output);
}
}

} // namespace mxnet

#endif
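
For context, the two MKLDNNRun overloads are exercised by the call sites changed above; a brief usage sketch follows, with the calls copied from the diff (concat.cc and lrn.cc) and the surrounding comments editorial only.

// Vector overload: the full input/output vectors are forwarded as-is.
MKLDNNRun(MKLDNNConcatForward, attrs, op_ctx, inputs, req, outputs);

// Unary overload added above: single-tensor kernels such as LRN pass one
// NDArray and one OpReqType; an input that is a view held in MKLDNN layout
// is reordered to the default layout before the kernel runs.
MKLDNNRun(MKLDNNLRNForward, attrs, ctx, inputs[0], req[0], outputs[0]);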
