Skip to content

Commit

Permalink
Fix transpose split optimize attr when opset >=13 (#1996)
Browse files Browse the repository at this point in the history
* add tranposeOptimizer split for opset >= 13

Signed-off-by: Deyu Huang <deyhuang@microsoft.com>

* add test

Signed-off-by: Deyu Huang <deyhuang@microsoft.com>

* fix comments

Signed-off-by: Deyu Huang <deyhuang@microsoft.com>
  • Loading branch information
hwangdeyu authored Jul 15, 2022
1 parent 9ce72be commit e896723
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
51 changes: 38 additions & 13 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_transpose_with_split(self, input_shape, perm, inner_perm):
((1, -1), (1, 1710), (1710,), [1, 0]),
((3, 1, 1, 5, -1), (3, 1, 1, 5, 6), (3, 5, 6), [0, 2, 3, 4, 1]),
])
@check_opset_max_version(12, "split attribute changed to input in opset 13")
@check_opset_max_version(12, "split attribute changed to input since opset 13")
def test_transpose_with_split_dynamic_shape(self, input_shape, specific_input, output_shape, perm):
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
node2 = helper.make_node("Split", ["Y"], ["Z"], axis=1, split=[1], name="split")
Expand All @@ -162,6 +162,31 @@ def test_transpose_with_split_dynamic_shape(self, input_shape, specific_input, o
self.run_transpose_compare(["B"], {"X": np.random.randn(*specific_input).astype(np.float32)},
model_proto, remaining_transpose_num=0)

@parameterized.expand([
((3, 1, 1), (1, 1, 3), (1), [0, 2, 3, 1]),
((256, 1, 1), (1, 1, 256), (1), [0, 2, 3, 1])
])
@check_opset_min_version(13, "split attribute changed to input since opset 13")
def test_transpose_with_split_opset13(self, input_shape, output_shape, split_val, perm):
unsqueeze_axes = self._make_onnx_const(np.array([0], dtype=np.int64), "axes1")
unsqueeze = helper.make_node("Unsqueeze", ["X", "axes1"], ["Y"], name="unsqueeze")
trans = helper.make_node("Transpose", ["Y"], ["Z"], perm=perm, name="trans")
split_attr = self._make_onnx_const(np.array([split_val], dtype=np.int64), "split_attr")
split = helper.make_node("Split", ["Z", "split_attr"], ["A"], axis=0, name="split")
squeeze_axes = self._make_onnx_const(np.array([1], dtype=np.int64), "axes2")
squeeze = helper.make_node("Squeeze", ["A", "axes2"], ["B"], name="squeeze")

graph = helper.make_graph(
[unsqueeze_axes, unsqueeze, trans, split_attr, split, squeeze_axes, squeeze],
"test_transpose_with_split_opset13",
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
[helper.make_tensor_value_info("B", TensorProto.FLOAT, output_shape)],
)

model_proto = self.make_model(graph, producer_name="onnx-tests")
self.run_transpose_compare(["B"], {"X": np.random.randn(*input_shape).astype(np.float32)},
model_proto, remaining_transpose_num=0)

@parameterized.expand([
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
Expand Down Expand Up @@ -717,7 +742,7 @@ def test_transpose_sqrt(self, shape, perm_input, perm_output):
((1, 3, 4, 5), (4, 5, 3), [0, 2, 3, 1], [1, 2, 0]),
((1, 3, 4, 5, 6), (4, 5, 6, 3), [0, 2, 3, 4, 1], [1, 2, 3, 0]),
])
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze1(self, input_shape, output_shape, perm, expected_perm):
# squeeze the first dim
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand Down Expand Up @@ -768,7 +793,7 @@ def test_transpose_with_unsqueeze(self, input_shape, output_shape, perm, axes_va
((1, 3, 4, 5), (4, 5, 3), [0, 2, 3, 1], [1, 2, 0]),
((1, 3, 4, 5, 6), (4, 5, 6, 3), [0, 2, 3, 4, 1], [1, 2, 3, 0]),
])
@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze1_13(self, input_shape, output_shape, perm, expected_perm):
# squeeze the first dim
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand All @@ -791,7 +816,7 @@ def test_transpose_with_squeeze1_13(self, input_shape, output_shape, perm, expec
((3, 4, 1, 5), (3, 5, 4), [0, 2, 3, 1], [0, 2, 1]),
((3, 4, 1, 5, 6), (3, 5, 6, 4), [0, 2, 3, 4, 1], [0, 2, 3, 1]),
])
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze2(self, input_shape, output_shape, perm, expected_perm):
# squeeze the second dim
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand All @@ -813,7 +838,7 @@ def test_transpose_with_squeeze2(self, input_shape, output_shape, perm, expected
((3, 4, 1, 5), (3, 5, 4), [0, 2, 3, 1], [0, 2, 1]),
((3, 4, 1, 5, 6), (3, 5, 6, 4), [0, 2, 3, 4, 1], [0, 2, 3, 1]),
])
@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze2_13(self, input_shape, output_shape, perm, expected_perm):
# squeeze the second dim
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand All @@ -836,7 +861,7 @@ def test_transpose_with_squeeze2_13(self, input_shape, output_shape, perm, expec
((3, 1, 4, 5), (3, 4, 5), [0, 2, 3, 1]),
((3, 1, 4, 5, 6), (3, 4, 5, 6), [0, 2, 3, 4, 1]),
])
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze3(self, input_shape, output_shape, perm):
# squeeze the last dim
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand All @@ -857,7 +882,7 @@ def test_transpose_with_squeeze3(self, input_shape, output_shape, perm):
((3, 1, 4, 5), (3, 4, 5), [0, 2, 3, 1]),
((3, 1, 4, 5, 6), (3, 4, 5, 6), [0, 2, 3, 4, 1]),
])
@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze3_13(self, input_shape, output_shape, perm):
# squeeze the last dim
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand All @@ -879,7 +904,7 @@ def test_transpose_with_squeeze3_13(self, input_shape, output_shape, perm):
((3, 1, 1, 5), (3, 5), [0, 2, 3, 1]),
((3, 1, 1, 5, 4), (3, 5, 4), [0, 2, 3, 4, 1]),
])
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze4(self, input_shape, output_shape, perm):
# squeeze the two dims
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand All @@ -900,7 +925,7 @@ def test_transpose_with_squeeze4(self, input_shape, output_shape, perm):
((3, 1, 1, 5), (3, 5), [0, 2, 3, 1]),
((3, 1, 1, 5, 4), (3, 5, 4), [0, 2, 3, 4, 1]),
])
@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
def test_transpose_with_squeeze4_13(self, input_shape, output_shape, perm):
# squeeze the two dims
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
Expand Down Expand Up @@ -2156,7 +2181,7 @@ def test_const_fold_concat(self):
self.run_and_compare(["res"], {"inp": np.random.randn(6, 12).astype(np.float32)}, model_proto,
"Concat", 0)

@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
def test_const_fold_unsqueeze_with_const(self):
shape = (6, 6)
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
Expand All @@ -2176,7 +2201,7 @@ def test_const_fold_unsqueeze_with_const(self):
self.run_and_compare(["res"], {"X": np.random.randn(1).astype(np.float32)}, model_proto,
"Unsqueeze", 0)

@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
@check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
def test_const_fold_unsqueeze_with_const_13(self):
shape = (6, 6)
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
Expand Down Expand Up @@ -2254,7 +2279,7 @@ def test_const_fold_split_one(self):
self.run_and_compare(["out4"], {"inp": np.random.randn(2, 6, 1).astype(np.float32)}, model_proto,
"Split", 0)

@check_opset_min_version(13, "Split changed in opset 13")
@check_opset_min_version(13, "Split changed since opset 13")
def test_const_fold_split_const_splits_13(self):
shape = (2, 6, 1)
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
Expand All @@ -2277,7 +2302,7 @@ def test_const_fold_split_const_splits_13(self):
self.run_and_compare(["out4"], {"inp": np.random.randn(2, 3, 1).astype(np.float32)}, model_proto,
"Split", 0)

@check_opset_max_version(12, "Split changed in opset 13")
@check_opset_max_version(12, "Split changed since opset 13")
def test_const_fold_split_const_splits(self):
shape = (2, 6, 1)
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
Expand Down
12 changes: 10 additions & 2 deletions tf2onnx/optimizer/transpose_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,11 +671,19 @@ def _concat_handler(self, trans, node):

def _split_handler(self, trans, node):
# Todo: need handle cases where Split node has more than 1 outputs.
split = None
if self._g.opset >= 13 and len(node.input) > 1 and node.inputs[1].is_const():
# split is an input not attr since opset 13
split = node.inputs[1].get_tensor_value(as_list=True)
if self._handle_node_having_branches(trans, node):
perm = trans.get_attr_value("perm")
axis = node.get_attr_value("axis", 0)
new_axis = perm[axis]
node.set_attr("axis", new_axis)
if split:
new_axes_np = np.array(split, dtype=np.int64)
new_axes_const = self._g.make_const(utils.make_name(node.inputs[1].name), new_axes_np)
self._g.replace_inputs(node, [node.input[0], new_axes_const.output[0]])
return True
return False

Expand Down Expand Up @@ -747,7 +755,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
shape_after_trans = [input_shape[i] for i in ori_perm]
output_shape = [shape_after_trans[i] for i in range(n) if i not in ori_squeeze_axes]
# calculate new_perm
# after switch, the output shape should be same, using this condtion we can figure the new perm
# after switch, the output shape should be same, using this condition we can figure the new perm
shape_after_squeeze = [input_shape[i] for i in range(n) if i not in new_squeeze_axes]
new_perm = [shape_after_squeeze.index(i) for i in output_shape]

Expand All @@ -757,7 +765,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
return False

axes = None
# in opset 13, axes is an input not attr
# axes is an input not attr since opset 13
if node.get_attr("axes"):
axes = node.get_attr("axes").ints
if len(node.input) > 1 and node.inputs[1].is_const():
Expand Down

0 comments on commit e896723

Please sign in to comment.