[BYOC] [DNNL] enable in-place post-op sum in dnnl json runtime (#12371)
* enable inplace post-op sum in dnnl byoc

* add inplace post-op sum test
yangulei authored Aug 12, 2022
1 parent de12486 commit e8de88e
Showing 3 changed files with 76 additions and 23 deletions.
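Background: oneDNN's sum post-op accumulates a primitive's result into whatever already occupies its DST buffer, so the tensor being summed must share storage with the destination; that is the in-place condition this commit arranges for in the JSON runtime. Below is a minimal sketch of a convolution fused with a sum post-op, assuming the oneDNN v2.x C++ API used by TVM at the time; shapes and names are illustrative, and this is not code from the commit.

```cpp
// Minimal sketch: convolution fused with a sum post-op (oneDNN v2.x API).
// The data already present in dst_mem is scaled by 1.0f and added to the
// convolution result, i.e. the accumulation happens in place in DST.
#include <dnnl.hpp>

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  memory::desc src_md({1, 32, 8, 8}, memory::data_type::f32, memory::format_tag::nchw);
  memory::desc wei_md({16, 32, 3, 3}, memory::data_type::f32, memory::format_tag::oihw);
  memory::desc dst_md({1, 16, 6, 6}, memory::data_type::f32, memory::format_tag::nchw);

  // Attach the sum post-op: dst = conv(src, wei) + 1.0f * dst
  post_ops po;
  po.append_sum(1.f);
  primitive_attr attr;
  attr.set_post_ops(po);

  auto conv_d = convolution_forward::desc(
      prop_kind::forward_inference, algorithm::convolution_direct,
      src_md, wei_md, dst_md,
      /*strides=*/{1, 1}, /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0});
  auto conv_pd = convolution_forward::primitive_desc(conv_d, attr, eng);

  memory src_mem(src_md, eng), wei_mem(wei_md, eng);
  // dst_mem must already hold the tensor to be summed before execution.
  memory dst_mem(dst_md, eng);

  convolution_forward(conv_pd).execute(
      strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, wei_mem}, {DNNL_ARG_DST, dst_mem}});
  strm.wait();
  return 0;
}
```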
35 changes: 34 additions & 1 deletion src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -83,6 +83,17 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
// Find proper dnnl::memory buffers
std::unordered_map<int, dnnl::memory> mem_args;
for (const auto& kvp : arg_reqs) mem_args[kvp.first] = mem_solver(kvp.second);

    // Skip the reorder if src == dst to enable in-place operation
if (prim.get_kind() == dnnl::primitive::kind::reorder) {
const auto& mem_src = mem_args.at(DNNL_ARG_SRC);
const auto& mem_dst = mem_args.at(DNNL_ARG_DST);
if ((mem_src.get_desc() == mem_dst.get_desc()) &&
(mem_src.get_data_handle() == mem_dst.get_data_handle())) {
continue;
}
}

prim.execute(stream_, mem_args);
}
}
@@ -845,12 +856,34 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
return TensorRequisite::AsIs(desc, eid).Backward();
}

bool IsIntermidate(const TensorRequisite& tr) {
auto eid = tr.eid();
bool is_input = std::find(input_nodes_.begin(), input_nodes_.end(), eid) != input_nodes_.end();
bool is_output = std::any_of(outputs_.begin(), outputs_.end(),
[eid](auto& output) { return output.id_ == eid; });
if (is_input || is_output) {
return false;
} else {
return true;
}
}

/*! \brief Helper function to register primitive into execution queue */
void Submit(const dnnl::primitive& prim, const std::unordered_map<int, TensorRequisite>& tr_args,
const std::pair<TensorRequisite, int>& inplace_conf = {}) {
// Register all provided TR arguments
std::unordered_map<int, TensorRegistry::ArgId> prim_arg_id;
TensorRegistry::ActionQue post_prim_actions;

// mark inplace tr
if (auto tr_in = inplace_conf.first) {
auto tr_out = tr_args.at(inplace_conf.second);
if (IsIntermidate(tr_in) && IsIntermidate(tr_out)) {
tensor_registry_.Register(tr_in, &net_);
tensor_registry_.MarkInplace(tr_out, tr_in);
}
}

for (const auto& kvp : tr_args) {
const auto& key = kvp.first;
const auto& tr = kvp.second;
@@ -860,7 +893,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
prim_arg_id[key] = arg_id;
}

// Simulate inplace primitive
    // Simulate in-place primitive; the reorder with src == dst will be skipped in Run()
if (auto tr = inplace_conf.first) {
auto arg_id = tensor_registry_.Register(tr, &net_);
auto dst_tr = tr_args.at(inplace_conf.second);
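The in-place path in Submit() registers what amounts to an identity reorder, with source and destination aliasing the same buffer, and the new check in Run() above recognizes and skips it at execution time. A minimal standalone restatement of that condition, using the public oneDNN C++ API (an illustrative helper, not part of the runtime):

```cpp
#include <dnnl.hpp>

// Returns true when executing `prim` would move no data: it is a reorder whose
// source and destination share both the memory descriptor and the data handle.
bool IsNoOpReorder(const dnnl::primitive& prim, const dnnl::memory& src,
                   const dnnl::memory& dst) {
  return prim.get_kind() == dnnl::primitive::kind::reorder &&
         src.get_desc() == dst.get_desc() &&
         src.get_data_handle() == dst.get_data_handle();
}
```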
25 changes: 25 additions & 0 deletions src/runtime/contrib/dnnl/dnnl_tensor_requisite.h
@@ -115,6 +115,23 @@ class TensorRequisite {
/*! \brief return tensor desc */
dnnl::memory::desc desc() const { return t_desc_; }

Tid eid() const {
auto res = kUndefinedTid;

if (!defined()) {
res = kUndefinedTid;
} else if (eid_ == kUndefinedTid) {
if (orig_) {
res = orig_->eid();
} else {
res = kUndefinedTid;
}
} else {
res = eid_;
}
return res;
}

/*! \brief Make TR with backward dataflow */
TensorRequisite Backward() const {
if (!defined()) return *this;
@@ -587,6 +604,14 @@ class TensorRegistry {
tmp_mem_collection_, tmp_mem_mapping_);
}

void MarkInplace(const TensorRequisite& tr, const TensorRequisite& shared) {
const auto tr_id = tr.eid();
ICHECK(tr_id != TensorRequisite::kUndefinedTid);
const auto shared_id = shared.eid();
ICHECK(shared_id != TensorRequisite::kUndefinedTid);
eid2idx_tmp_[tr_id] = eid2idx_tmp_[shared_id];
}

private:
ArgId RegisterReinterpret(ArgId src_ar, const dnnl::memory::desc& desc) {
switch (src_ar.flag_) {
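The new MarkInplace() works by aliasing entries in the temporary-buffer index map (eid2idx_tmp_): the entry id of one tensor is pointed at the slot already assigned to the tensor it shares storage with, so both resolve to the same scratch memory at run time. Below is a simplified, standalone sketch of that aliasing idea with hypothetical names (ToyRegistry, eid2slot_), not TVM's actual classes:

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Toy registry: each tensor entry id (eid) maps to a slot in a pool of
// temporary buffers.  Aliasing two eids to the same slot makes the consumer
// write directly into the producer's buffer, i.e. an in-place update.
class ToyRegistry {
 public:
  int Register(int eid, size_t bytes) {
    auto it = eid2slot_.find(eid);
    if (it != eid2slot_.end()) return it->second;  // already has a buffer
    int slot = static_cast<int>(pool_.size());
    pool_.emplace_back(bytes);                     // allocate a new scratch buffer
    eid2slot_[eid] = slot;
    return slot;
  }

  // Point `eid` at the slot already owned by `shared_eid` (in-place aliasing).
  void MarkInplace(int eid, int shared_eid) {
    assert(eid2slot_.count(shared_eid));
    eid2slot_[eid] = eid2slot_.at(shared_eid);
  }

  void* DataOf(int eid) { return pool_.at(eid2slot_.at(eid)).data(); }

 private:
  std::unordered_map<int, int> eid2slot_;
  std::vector<std::vector<char>> pool_;
};

int main() {
  ToyRegistry reg;
  reg.Register(/*eid=*/3, /*bytes=*/1024);     // producer's intermediate tensor
  reg.MarkInplace(/*eid=*/7, /*shared_eid=*/3);
  reg.Register(/*eid=*/7, /*bytes=*/1024);     // no new allocation: already aliased
  assert(reg.DataOf(3) == reg.DataOf(7));      // both eids resolve to one buffer
  return 0;
}
```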
39 changes: 17 additions & 22 deletions tests/python/contrib/test_dnnl.py
@@ -450,13 +450,6 @@ def get_layer_norm(x_shape=(1, 49, 64), dtype="float32"):
return out, dic, param_lst


def get_conv2d_bias_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype)
sum_data = relay.const(np.random.randint(x_shape).astype(dtype))
conv2d_bias_sum = relay.add(sum_data, conv2d_bias)
return relay.nn.relu(conv2d_bias_sum), dic, param_lst


def get_conv3d(
x_shape=(1, 32, 8, 8, 8),
k_shape=(16, 32, 3, 3, 3),
@@ -799,7 +792,7 @@ def test_conv2d_bias_sum_relu(run_module, dtype="float32"):
x_shape = (1, 32, 8, 8)
k_shape = (16, 32, 3, 3)

def get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape, dtype="float32"):
def get_conv2d_bn_sum_relu(x_shape, k_shape, dtype="float32"):
out, dic, param_lst = get_conv2d_bias(x_shape=x_shape, k_shape=k_shape, dtype=dtype)
beta = relay.const(np.zeros(k_shape[0]).astype(dtype))
gamma = relay.const(np.ones(k_shape[0]).astype(dtype))
@@ -816,22 +809,24 @@ def get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape, dtype="float32"):
scale=True,
epsilon=1e-5,
)
sum_data = relay.var("data1", shape=sum_shape, dtype=dtype)
out = relay.add(out, sum_data)
dic["data1"] = sum_shape
param_lst += ["data1"]
sum_in = relay.var("sum_in", shape=x_shape, dtype=dtype)
kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype))
conv_sum = relay.nn.conv2d(
sum_in,
kernel,
channels=k_shape[0],
kernel_size=k_shape[2:4],
groups=1,
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
)
# sum over two conv2d outputs to meet inplace condition
out = relay.add(out, conv_sum)
dic["sum_in"] = x_shape
return relay.nn.relu(out), dic, param_lst

conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype
)
conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
config = conv2d_bn_sum_relu, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)

conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype
)
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, dtype=dtype)
conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
config = conv2d_bn_sum_relu, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)
