Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed group attribute in convolution op #2090

Merged
merged 11 commits into from
May 31, 2023
14 changes: 14 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6138,6 +6138,20 @@ def func(x):
x_val = make_xval([2, 3])
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_opset_min_version(11, "Pad")
def test_conv_unknown_kernel_channels(self):
x_shape = [2, 10, 3]
x_val = make_xval(x_shape)
kernel_shape = [4, 3, 5]
kernel_val = make_xval(kernel_shape)
pad_val = np.array([[0, 0], [0, 0], [0, 0]], np.int64)
def func(x, kernel, pad):
# Make kernel dimensions unknown
kernel = tf.pad(kernel, pad)
conv = tf.nn.conv1d(x, kernel, stride=[1], padding='VALID')
return tf.identity(conv, name='output')
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: kernel_val, _INPUT2: pad_val})

@check_tf_min_version("2.3.0")
@check_opset_min_version(16, "ScatterND")
@skip_tfjs("not supported in tfjs")
Expand Down
14 changes: 9 additions & 5 deletions tf2onnx/onnx_opset/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,9 @@ def conv_kernel_shape(ctx, node, input_idx, spatial=2):
# Get spatial part.
kernel_shape = kernel_shape[:spatial]

# Set new value and return it.
node.set_attr("kernel_shape", kernel_shape)
# Set attribute value only if all dimensions are known.
if all(d > 0 for d in kernel_shape):
node.set_attr("kernel_shape", kernel_shape)

return kernel_shape

Expand Down Expand Up @@ -379,11 +380,13 @@ def any_version(cls, opset, ctx, node, **kwargs):
data_format = str(node.attr["data_format"].s, encoding="utf8")
shape_dim = -1
if data_format == "NHWC":
shape_dim = ctx.get_shape(node.input[0])[3]
shape_dim = ctx.get_shape(node.input[0])[-1]
elif data_format == "NCHW":
shape_dim = ctx.get_shape(node.input[0])[1]
if shape_dim != -1:
groups = int(shape_dim / ctx.get_shape(node.input[1])[2])
filter_in_channels = ctx.get_shape(node.input[1])[-2]
if filter_in_channels != -1:
groups = shape_dim // filter_in_channels

node.set_attr("group", groups)

Expand Down Expand Up @@ -649,7 +652,8 @@ def version_1(cls, ctx, node, **kwargs):
raise ValueError("input channel must be positive")
k_output_channels = k_input_channels * k_channel_multiplier

node.set_attr("kernel_shape", [k_h, k_w])
if k_h > 0 and k_w > 0:
node.set_attr("kernel_shape", [k_h, k_w])
strides = conv_dims_attr(node, "strides")
dilations = conv_dims_attr(node, "dilations")
node.set_attr("group", k_input_channels)
Expand Down