From 6c5e39b69772e1e71e32df21dd7dc167e58c7ce7 Mon Sep 17 00:00:00 2001
From: Youlei Yang
Date: Thu, 4 Aug 2022 01:06:27 -0700
Subject: [PATCH 1/2] enable inplace post-op sum in dnnl byoc

The runtime emulates DNNL's inplace post-op sum with a helper reorder that
copies the summed tensor into the destination buffer. When both the summed
tensor and the destination are intermediate tensors (neither graph input nor
output), Submit() now marks the pair in the TensorRegistry so they share one
temporary buffer, and Run() skips the helper reorder once its src and dst
memory objects are identical, removing the extra copy.

---
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 35 ++++++++++++++++++-
 .../contrib/dnnl/dnnl_tensor_requisite.h      | 25 +++++++++++++
 2 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index d019f4e811ed..9deab71b5102 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -83,6 +83,17 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
       // Find proper dnnl::memory buffers
       std::unordered_map<int, dnnl::memory> mem_args;
       for (const auto& kvp : arg_reqs) mem_args[kvp.first] = mem_solver(kvp.second);
+
+      // Skip the reorder if src == dst to enable inplace operation
+      if (prim.get_kind() == dnnl::primitive::kind::reorder) {
+        const auto& mem_src = mem_args.at(DNNL_ARG_SRC);
+        const auto& mem_dst = mem_args.at(DNNL_ARG_DST);
+        if ((mem_src.get_desc() == mem_dst.get_desc()) &&
+            (mem_src.get_data_handle() == mem_dst.get_data_handle())) {
+          continue;
+        }
+      }
+
       prim.execute(stream_, mem_args);
     }
   }
@@ -845,12 +856,34 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     return TensorRequisite::AsIs(desc, eid).Backward();
   }
 
+  bool IsIntermediate(const TensorRequisite& tr) {
+    auto eid = tr.eid();
+    bool is_input = std::find(input_nodes_.begin(), input_nodes_.end(), eid) != input_nodes_.end();
+    bool is_output = std::any_of(outputs_.begin(), outputs_.end(),
+                                 [eid](auto& output) { return output.id_ == eid; });
+    if (is_input || is_output) {
+      return false;
+    } else {
+      return true;
+    }
+  }
+
   /*! \brief Helper function to register primitive into execution queue */
   void Submit(const dnnl::primitive& prim, const std::unordered_map<int, TensorRequisite>& tr_args,
               const std::pair<TensorRequisite, int>& inplace_conf = {}) {
     // Register all provided TR arguments
     std::unordered_map<int, TensorRegistry::ArgId> prim_arg_id;
     TensorRegistry::ActionQue post_prim_actions;
+
+    // Mark the inplace pair so the output TR reuses the buffer of the input TR
+    if (auto tr_in = inplace_conf.first) {
+      auto tr_out = tr_args.at(inplace_conf.second);
+      if (IsIntermediate(tr_in) && IsIntermediate(tr_out)) {
+        tensor_registry_.Register(tr_in, &net_);
+        tensor_registry_.MarkInplace(tr_out, tr_in);
+      }
+    }
+
     for (const auto& kvp : tr_args) {
       const auto& key = kvp.first;
       const auto& tr = kvp.second;
@@ -860,7 +893,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
       prim_arg_id[key] = arg_id;
     }
 
-    // Simulate inplace primitive
+    // Simulate inplace primitive; the reorder with src == dst will be skipped in Run()
     if (auto tr = inplace_conf.first) {
       auto arg_id = tensor_registry_.Register(tr, &net_);
       auto dst_tr = tr_args.at(inplace_conf.second);
diff --git a/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h
index bad4bc10edec..e3867f27bc71 100644
--- a/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h
+++ b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h
@@ -115,6 +115,23 @@ class TensorRequisite {
   /*! \brief return tensor desc */
   dnnl::memory::desc desc() const { return t_desc_; }
 
+  Tid eid() const {
+    auto res = kUndefinedTid;
+
+    if (!defined()) {
+      res = kUndefinedTid;
+    } else if (eid_ == kUndefinedTid) {
+      if (orig_) {
+        res = orig_->eid();
+      } else {
+        res = kUndefinedTid;
+      }
+    } else {
+      res = eid_;
+    }
+    return res;
+  }
+
   /*! \brief Make TR with backward dataflow */
   TensorRequisite Backward() const {
     if (!defined()) return *this;
@@ -587,6 +604,14 @@ class TensorRegistry {
                       tmp_mem_collection_, tmp_mem_mapping_);
   }
 
+  void MarkInplace(const TensorRequisite& tr, const TensorRequisite& shared) {
+    const auto tr_id = tr.eid();
+    ICHECK(tr_id != TensorRequisite::kUndefinedTid);
+    const auto shared_id = shared.eid();
+    ICHECK(shared_id != TensorRequisite::kUndefinedTid);
+    eid2idx_tmp_[tr_id] = eid2idx_tmp_[shared_id];
+  }
+
  private:
   ArgId RegisterReinterpret(ArgId src_ar, const dnnl::memory::desc& desc) {
     switch (src_ar.flag_) {

From 65b415e670a2b2cf824d4d17ad739295d2f1af11 Mon Sep 17 00:00:00 2001
From: Youlei Yang
Date: Thu, 4 Aug 2022 01:06:48 -0700
Subject: [PATCH 2/2] add inplace post-op sum test

Rework test_conv2d_bias_sum_relu so that the summed operand is itself the
output of a conv2d; both summands are then intermediate tensors, which is the
condition the runtime requires to run the post-op sum inplace.

---
 tests/python/contrib/test_dnnl.py | 39 ++++++++++++++-----------------
 1 file changed, 17 insertions(+), 22 deletions(-)

diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 8de8bd9ce687..c4adc9785c19 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -450,13 +450,6 @@ def get_layer_norm(x_shape=(1, 49, 64), dtype="float32"):
     return out, dic, param_lst
 
 
-def get_conv2d_bias_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
-    conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype)
-    sum_data = relay.const(np.random.randint(x_shape).astype(dtype))
-    conv2d_bias_sum = relay.add(sum_data, conv2d_bias)
-    return relay.nn.relu(conv2d_bias_sum), dic, param_lst
-
-
 def get_conv3d(
     x_shape=(1, 32, 8, 8, 8),
     k_shape=(16, 32, 3, 3, 3),
@@ -799,7 +792,7 @@ def test_conv2d_bias_sum_relu(run_module, dtype="float32"):
     x_shape = (1, 32, 8, 8)
     k_shape = (16, 32, 3, 3)
 
-    def get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape, dtype="float32"):
+    def get_conv2d_bn_sum_relu(x_shape, k_shape, dtype="float32"):
         out, dic, param_lst = get_conv2d_bias(x_shape=x_shape, k_shape=k_shape, dtype=dtype)
         beta = relay.const(np.zeros(k_shape[0]).astype(dtype))
         gamma = relay.const(np.ones(k_shape[0]).astype(dtype))
@@ -816,22 +809,24 @@ def get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape, dtype="float32"):
             scale=True,
             epsilon=1e-5,
         )
-        sum_data = relay.var("data1", shape=sum_shape, dtype=dtype)
-        out = relay.add(out, sum_data)
-        dic["data1"] = sum_shape
-        param_lst += ["data1"]
+        sum_in = relay.var("sum_in", shape=x_shape, dtype=dtype)
+        kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype))
+        conv_sum = relay.nn.conv2d(
+            sum_in,
+            kernel,
+            channels=k_shape[0],
+            kernel_size=k_shape[2:4],
+            groups=1,
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+        )
+        # Sum the outputs of two conv2d ops so both summands are intermediate tensors
+        out = relay.add(out, conv_sum)
+        dic["sum_in"] = x_shape
         return relay.nn.relu(out), dic, param_lst
 
-    conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
-        x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype
-    )
-    conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
-    config = conv2d_bn_sum_relu, dic, param_lst
-    run_and_verify_func(config, run_module=run_module, dtype=dtype)
-
-    conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
-        x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype
-    )
+    conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, dtype=dtype)
     conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
     config = conv2d_bn_sum_relu, dic, param_lst
     run_and_verify_func(config, run_module=run_module, dtype=dtype)
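As a usage sketch, the graph shape exercised by the reworked test can be reproduced outside the test harness with a few lines of Relay: a sum whose two operands are both intermediate conv2d outputs, which is the condition the runtime checks before marking the post-op sum as inplace. This sketch is illustrative only and not part of the patch series; it assumes a TVM build with the DNNL BYOC codegen enabled and that `partition_for_dnnl` is available in `tvm.relay.op.contrib.dnnl` in your TVM version.

```python
# Illustrative sketch (not part of the patches): offload a graph whose sum adds
# two intermediate conv2d outputs to the DNNL BYOC runtime.
# Assumes TVM was built with the DNNL codegen enabled.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import partition_for_dnnl

dtype = "float32"
x_shape, k_shape = (1, 32, 8, 8), (16, 32, 3, 3)

data = relay.var("data", shape=x_shape, dtype=dtype)
sum_in = relay.var("sum_in", shape=x_shape, dtype=dtype)
w1 = relay.const(np.random.uniform(-1, 1, k_shape).astype(dtype))
w2 = relay.const(np.random.uniform(-1, 1, k_shape).astype(dtype))

# Two conv2d ops with identical output shapes; their sum feeds a relu, so each
# summand is an intermediate tensor rather than a graph input or output.
lhs = relay.nn.conv2d(data, w1, channels=k_shape[0], kernel_size=k_shape[2:4])
rhs = relay.nn.conv2d(sum_in, w2, channels=k_shape[0], kernel_size=k_shape[2:4])
out = relay.nn.relu(relay.add(lhs, rhs))

# Partition supported ops into DNNL regions and build for CPU.
mod = partition_for_dnnl(tvm.IRModule.from_expr(out))
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")
```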