From 62f9b1d29ae25fbdeb425bfc600c5dac7c23f694 Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Fri, 14 Apr 2023 15:12:29 +0800 Subject: [PATCH] [Tensorflow] Fix conv2d_transpose for NHWC layout (#14546) * [Tensorflow] Fix conv2d_transpose for NHWC layout If "data_format" == "NHWC", the kernel_layout should be "HWOI" rather than "HWIO". * remove deed code * add test cases * Update test_forward.py * Update test_forward.py * Update tensorflow_ops.py * Update tensorflow_ops.py --- python/tvm/relay/frontend/tensorflow_ops.py | 4 ++-- .../frontend/tensorflow/test_forward.py | 21 ++++++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index 6b3f144619dd..27374fad1a94 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -464,8 +464,8 @@ def _impl(inputs, attr, params, mod): if opname == "conv": attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "OIHW" elif opname == "conv_transpose": - # conv_transpose in TVM has weights be IOHW for NCHW - attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "IOHW" + # conv_transpose has weights be IOHW, because the attr["data_format"] always be NCHW + attr["kernel_layout"] = "IOHW" else: attr["kernel_layout"] = "HWOI" if attr["data_format"] == "NHWC" else "OIHW" diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 703df799423f..bd966fa71ccc 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -742,7 +742,16 @@ def test_forward_convolution(): "NCHW", [1, 1, 8, 8], ) - + _test_convolution( + "conv_transpose", + [4, 19, 8, 8], + [2, 2, 66, 19], + [1, 1], + [2, 2], + "VALID", + "NCHW", + [4, 66, 16, 16], + ) _test_convolution("conv", [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC") _test_convolution("conv", [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC") _test_convolution("conv", [4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC") @@ -917,6 +926,16 @@ def test_forward_convolution(): [4, 8, 8, 176], add_shapes_to_graph_def=False, ) + _test_convolution( + "conv_transpose", + [4, 8, 8, 19], + [2, 2, 66, 19], + [1, 1], + [2, 2], + "VALID", + "NHWC", + [4, 16, 16, 66], + ) # Explicit padding if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"): _test_convolution(