Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed group attribute in convolution op #2090

Merged
merged 11 commits into from
May 31, 2023
Merged

Fixed group attribute in convolution op #2090

merged 11 commits into from
May 31, 2023

Conversation

javidcf
Copy link
Contributor

@javidcf javidcf commented Dec 1, 2022

Also minor change reading shape dimensions so it works for different dimensionalities.
Fixes #2084

…ape dimensions so it works for different dimensionalities.

Signed-off-by: Javier Dehesa <javidcf@gmail.com>
@fatcat-z
Copy link
Collaborator

The code LGTM, thanks for your contributions.

Could you please add a test (https://github.com/onnx/tensorflow-onnx/blob/main/tests/test_backend.py) to cover this?

@javidcf
Copy link
Contributor Author

javidcf commented Mar 1, 2023

I wrote this test case for this:

@check_opset_min_version(11, "Pad")
def test_conv_unknown_kernel_channels(self):
    x_shape = [2, 10, 3]
    x_val = make_xval(x_shape)
    kernel_shape = [4, 3, 5]
    kernel_val = make_xval(kernel_shape)
    pad_val = np.array([[0, 0], [0, 0], [0, 0]], np.int64)
    def func(x, kernel, pad):
        # Make kernel dimensions unknown
        kernel = tf.pad(kernel, pad)
        conv = tf.nn.conv1d(x, kernel, stride=[1], padding='VALID')
        return tf.identity(conv, name='output')
    self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: kernel_val, _INPUT2: pad_val})

Unfortunately, in the process of doing so I found another bug (already present before the fix) with convolution kernels with unknown shape. My fix solves the issue of the wrong value in the op group attribute, but attempting to run the above test function with ONNX (after my fix) results in this error:

[ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Conv node. Name:'conv1d' Status Message: kernel_shape is not compatible with W shape. kernel_shape: {1,-1} W: {5,3,1,4}

The error can also be reproduced before my fix by simply adding the following line before the call to tf.nn.conv1d:

x = tf.pad(x, pad)

The problem is I think this is actually an onnxruntime issue. This line from onnxruntime seems to suggest the code is not properly checking that the dimensions are known (not -1), or maybe kernel_shape_specified should be false in that case...

Either way, I'm not sure if a different test should be added which somehow just checks the attribute value with a graph_validator without trying to run the operation (which I'm not sure if it's possible, since ops not used by the result get pruned) or leave the test as it is and keep this "on hold" until the other issue is resolved.

EDIT: Another possibility is that the kernel_shape attribute of the Conv operation should always be fully defined. The operator spec does not explicitly say whether partially defined shapes are allowed or not.
EDIT 2: Submitted issue microsoft/onnxruntime#14869 about it.
EDIT 3: Seems that -1 is indeed not allowed in kernel_shape, so tf2onnx should not generate the attribute (which is optional for convolution operations anyway) if any of its values is not known. I'll add that fix.

javidcf and others added 5 commits March 1, 2023 17:52
Skip adding the optional attribute to Conv nodes if any shape value is negative.

Signed-off-by: Javier Dehesa <javidcf@gmail.com>
Signed-off-by: Javier Dehesa <javidcf@gmail.com>
tf2onnx/onnx_opset/nn.py Outdated Show resolved Hide resolved
tf2onnx/onnx_opset/nn.py Outdated Show resolved Hide resolved
fatcat-z and others added 2 commits March 17, 2023 13:42
Fixed convolution kernel dimension checks

Signed-off-by: Javier Dehesa <javidcf@gmail.com>
@javidcf javidcf closed this May 23, 2023
@javidcf javidcf deleted the patch-1 branch May 23, 2023 14:52
@javidcf javidcf restored the patch-1 branch May 23, 2023 14:53
@javidcf javidcf reopened this May 23, 2023
@javidcf
Copy link
Contributor Author

javidcf commented May 23, 2023

Apologies, I had deleted this branch by mistake.

@fatcat-z
Copy link
Collaborator

Could you please resolve the conflict in test_backend.py file?

Signed-off-by: Javier Dehesa <javidcf@gmail.com>
Copy link
Collaborator

@fatcat-z fatcat-z left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contributions!

@fatcat-z fatcat-z merged commit 554d90a into onnx:main May 31, 2023
@javidcf javidcf deleted the patch-1 branch May 31, 2023 06:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Negative group attribute in generated convolution operations
2 participants