Skip to content

Commit

Permalink
[Tensorflow] Fix conv2d_transpose for NHWC layout (#14546)
Browse files Browse the repository at this point in the history
* [Tensorflow] Fix conv2d_transpose for NHWC layout

If "data_format" == "NHWC", the kernel_layout should be "HWOI" rather than "HWIO".

* remove deed code

* add test cases

* Update test_forward.py

* Update test_forward.py

* Update tensorflow_ops.py

* Update tensorflow_ops.py
  • Loading branch information
jikechao authored Apr 14, 2023
1 parent 670d128 commit 62f9b1d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/tensorflow_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,8 @@ def _impl(inputs, attr, params, mod):
if opname == "conv":
attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "OIHW"
elif opname == "conv_transpose":
# conv_transpose in TVM has weights be IOHW for NCHW
attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "IOHW"
# conv_transpose has weights be IOHW, because the attr["data_format"] always be NCHW
attr["kernel_layout"] = "IOHW"
else:
attr["kernel_layout"] = "HWOI" if attr["data_format"] == "NHWC" else "OIHW"

Expand Down
21 changes: 20 additions & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,16 @@ def test_forward_convolution():
"NCHW",
[1, 1, 8, 8],
)

_test_convolution(
"conv_transpose",
[4, 19, 8, 8],
[2, 2, 66, 19],
[1, 1],
[2, 2],
"VALID",
"NCHW",
[4, 66, 16, 16],
)
_test_convolution("conv", [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC")
_test_convolution("conv", [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC")
_test_convolution("conv", [4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC")
Expand Down Expand Up @@ -917,6 +926,16 @@ def test_forward_convolution():
[4, 8, 8, 176],
add_shapes_to_graph_def=False,
)
_test_convolution(
"conv_transpose",
[4, 8, 8, 19],
[2, 2, 66, 19],
[1, 1],
[2, 2],
"VALID",
"NHWC",
[4, 16, 16, 66],
)
# Explicit padding
if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"):
_test_convolution(
Expand Down

0 comments on commit 62f9b1d

Please sign in to comment.