[BYOC] [DNNL] enable in-place post-op sum in dnnl json runtime #12371

Merged: 2 commits, Aug 12, 2022
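For context on the feature itself: oneDNN's "sum" post-op folds the elementwise add into the producing primitive, so the convolution accumulates directly into a destination buffer that already holds the other addend. The snippet below is a minimal sketch of that attribute setup against the public oneDNN C++ API; it is illustrative only, not code from this PR, and the scale value is an assumption.

#include <dnnl.hpp>

int main() {
  // A "sum" post-op folds the elementwise add into the convolution:
  //   dst := conv(src, weights) + 1.0f * dst
  dnnl::post_ops po;
  po.append_sum(1.0f);
  dnnl::primitive_attr attr;
  attr.set_post_ops(po);

  // Because DST must already contain the other addend, a runtime either copies
  // that tensor into DST with an extra reorder, or lets DST alias the addend's
  // buffer so the copy becomes unnecessary. This PR enables the latter,
  // in-place, path in the DNNL JSON runtime.
  return 0;
}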
35 changes: 34 additions & 1 deletion src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -83,6 +83,17 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
// Find proper dnnl::memory buffers
std::unordered_map<int, dnnl::memory> mem_args;
for (const auto& kvp : arg_reqs) mem_args[kvp.first] = mem_solver(kvp.second);

// Skip the reorder if src == dst to enable in-place operation.
if (prim.get_kind() == dnnl::primitive::kind::reorder) {
const auto& mem_src = mem_args.at(DNNL_ARG_SRC);
const auto& mem_dst = mem_args.at(DNNL_ARG_DST);
if ((mem_src.get_desc() == mem_dst.get_desc()) &&
(mem_src.get_data_handle() == mem_dst.get_data_handle())) {
continue;
}
}

prim.execute(stream_, mem_args);
}
}
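As a standalone illustration of the condition checked above (not TVM code; the engine, shape, and memory layout are assumptions), a reorder whose source and destination share both the memory descriptor and the data handle would only copy a buffer onto itself, so it is safe to skip:

#include <dnnl.hpp>
#include <cassert>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::memory::desc md({1, 16, 6, 6}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::nchw);

  // One buffer referenced as both source and destination of a reorder.
  dnnl::memory buf(md, eng);
  dnnl::memory src = buf;
  dnnl::memory dst = buf;

  // Same descriptor and same data handle: the reorder is a no-op, which is
  // exactly the case Run() now skips.
  bool is_noop = (src.get_desc() == dst.get_desc()) &&
                 (src.get_data_handle() == dst.get_data_handle());
  assert(is_noop);
  return 0;
}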
@@ -845,12 +856,34 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
return TensorRequisite::AsIs(desc, eid).Backward();
}

/*! \brief Check whether a tensor requisite is neither a graph input nor a graph output */
bool IsIntermediate(const TensorRequisite& tr) {
auto eid = tr.eid();
bool is_input = std::find(input_nodes_.begin(), input_nodes_.end(), eid) != input_nodes_.end();
bool is_output = std::any_of(outputs_.begin(), outputs_.end(),
[eid](auto& output) { return output.id_ == eid; });
return !(is_input || is_output);
}

/*! \brief Helper function to register primitive into execution queue */
void Submit(const dnnl::primitive& prim, const std::unordered_map<int, TensorRequisite>& tr_args,
const std::pair<TensorRequisite, int>& inplace_conf = {}) {
// Register all provided TR arguments
std::unordered_map<int, TensorRegistry::ArgId> prim_arg_id;
TensorRegistry::ActionQue post_prim_actions;

// Mark the in-place TR: let the destination tensor share the source buffer when both are intermediate tensors.
if (auto tr_in = inplace_conf.first) {
auto tr_out = tr_args.at(inplace_conf.second);
if (IsIntermediate(tr_in) && IsIntermediate(tr_out)) {
tensor_registry_.Register(tr_in, &net_);
tensor_registry_.MarkInplace(tr_out, tr_in);
}
}

for (const auto& kvp : tr_args) {
const auto& key = kvp.first;
const auto& tr = kvp.second;
@@ -860,7 +893,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
prim_arg_id[key] = arg_id;
}

// Simulate inplace primitive
// Simulate an in-place primitive; the reorder with src == dst will be skipped in Run().
if (auto tr = inplace_conf.first) {
auto arg_id = tensor_registry_.Register(tr, &net_);
auto dst_tr = tr_args.at(inplace_conf.second);
25 changes: 25 additions & 0 deletions src/runtime/contrib/dnnl/dnnl_tensor_requisite.h
@@ -115,6 +115,23 @@ class TensorRequisite {
/*! \brief return tensor desc */
dnnl::memory::desc desc() const { return t_desc_; }

/*! \brief Return the tensor entry id, falling back to the original TR when this one has none */
Tid eid() const {
if (!defined()) return kUndefinedTid;
if (eid_ != kUndefinedTid) return eid_;
return orig_ ? orig_->eid() : kUndefinedTid;
}

/*! \brief Make TR with backward dataflow */
TensorRequisite Backward() const {
if (!defined()) return *this;
@@ -587,6 +604,14 @@ class TensorRegistry {
tmp_mem_collection_, tmp_mem_mapping_);
}

/*! \brief Map a tensor to the same temporary-buffer slot as the shared tensor */
void MarkInplace(const TensorRequisite& tr, const TensorRequisite& shared) {
const auto tr_id = tr.eid();
ICHECK(tr_id != TensorRequisite::kUndefinedTid);
const auto shared_id = shared.eid();
ICHECK(shared_id != TensorRequisite::kUndefinedTid);
eid2idx_tmp_[tr_id] = eid2idx_tmp_[shared_id];
}

private:
ArgId RegisterReinterpret(ArgId src_ar, const dnnl::memory::desc& desc) {
switch (src_ar.flag_) {
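To make the buffer sharing in MarkInplace concrete, here is a tiny sketch of the underlying idea (illustrative only; the map name mirrors the code above, but the surrounding setup is an assumption): two tensor entry ids are pointed at the same temporary-buffer slot, so writing the fused result to one of them also fills the other, and the reorder between them degenerates into the skippable src == dst case.

#include <cassert>
#include <unordered_map>

int main() {
  // Tensor entry id -> index of a temporary buffer slot.
  std::unordered_map<int, int> eid2idx_tmp;

  const int shared_eid = 7;  // tensor feeding the post-op sum
  const int out_eid = 9;     // output of the fused primitive

  eid2idx_tmp[shared_eid] = 0;                     // shared tensor gets slot 0
  eid2idx_tmp[out_eid] = eid2idx_tmp[shared_eid];  // output reuses the same slot

  // Both ids now resolve to the same memory, giving an in-place post-op sum.
  assert(eid2idx_tmp[out_eid] == eid2idx_tmp[shared_eid]);
  return 0;
}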
39 changes: 17 additions & 22 deletions tests/python/contrib/test_dnnl.py
@@ -450,13 +450,6 @@ def get_layer_norm(x_shape=(1, 49, 64), dtype="float32"):
return out, dic, param_lst


def get_conv2d_bias_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype)
sum_data = relay.const(np.random.randint(x_shape).astype(dtype))
conv2d_bias_sum = relay.add(sum_data, conv2d_bias)
return relay.nn.relu(conv2d_bias_sum), dic, param_lst


def get_conv3d(
x_shape=(1, 32, 8, 8, 8),
k_shape=(16, 32, 3, 3, 3),
@@ -799,7 +792,7 @@ def test_conv2d_bias_sum_relu(run_module, dtype="float32"):
x_shape = (1, 32, 8, 8)
k_shape = (16, 32, 3, 3)

def get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape, dtype="float32"):
def get_conv2d_bn_sum_relu(x_shape, k_shape, dtype="float32"):
out, dic, param_lst = get_conv2d_bias(x_shape=x_shape, k_shape=k_shape, dtype=dtype)
beta = relay.const(np.zeros(k_shape[0]).astype(dtype))
gamma = relay.const(np.ones(k_shape[0]).astype(dtype))
@@ -816,22 +809,24 @@ def get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape, dtype="float32"):
scale=True,
epsilon=1e-5,
)
sum_data = relay.var("data1", shape=sum_shape, dtype=dtype)
out = relay.add(out, sum_data)
dic["data1"] = sum_shape
param_lst += ["data1"]
sum_in = relay.var("sum_in", shape=x_shape, dtype=dtype)
kernel = relay.const(np.random.randint(0, 2, k_shape).astype(dtype))
conv_sum = relay.nn.conv2d(
sum_in,
kernel,
channels=k_shape[0],
kernel_size=k_shape[2:4],
groups=1,
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
)
# Sum two conv2d outputs so that both addends are intermediate tensors, satisfying the in-place condition.
out = relay.add(out, conv_sum)
dic["sum_in"] = x_shape
return relay.nn.relu(out), dic, param_lst

conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype
)
conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
config = conv2d_bn_sum_relu, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)

conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype
)
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, dtype=dtype)
conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
config = conv2d_bn_sum_relu, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)
Expand Down