[Relay][Frontend] Add support for aten::concat #16199
Conversation
@sweetcocoa Thank you for your PR!
@mshr-h, I agree with your comments.
Thank you @sweetcocoa for your PR! I agree the test is needed.
LGTM, Thanks!
```
@@ -719,10 +719,27 @@ def forward(self, *args):
        b = (args[0][:, :, 1] + 3) * 11
        c = (args[0][:, :, 2] + 5) * 13
        return torch.cat([t.unsqueeze(2) for t in [a, b, c]], 2)

```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove the trailing whitespace.
```
@@ -2893,6 +2911,7 @@ def forward(self, inp):
@tvm.testing.uses_gpu
def test_simple_rnn():
    """test_simple_rnn"""

```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I suppose all such whitespace is redundant. Could you please remove it in this test and in the tests below?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I didn't see it; I've reverted it now.
Force-pushed from 03ddc08 to 6e23667.
I don't think this CI error stems from this PR; can I try restarting it?
Yes, you can comment with `@tvm-bot rerun`.
@tvm-bot rerun
```python
class Concatenate3(Module):
    # pylint: disable=missing-class-docstring
    def __init__(self):
        super().__init__()

        class _Concatenate(Module):
            def forward(self, *args):
                a = (args[0][:, :, 0] + 2) * 7
                b = (args[0][:, :, 1] + 3) * 11
                c = (args[0][:, :, 2] + 5) * 13
                return torch.concat([t.unsqueeze(2) for t in [a, b, c]], 2)

        self.mod = _Concatenate()

    def forward(self, *args):
        return self.mod(*args)
```
Why do you create a class in this way?
Will the following code work in the same way?
Suggested change:

```python
class Concatenate3(Module):
    def forward(self, *args):
        a = (args[0][:, :, 0] + 2) * 7
        b = (args[0][:, :, 1] + 3) * 11
        c = (args[0][:, :, 2] + 5) * 13
        return torch.concat([t.unsqueeze(2) for t in [a, b, c]], 2)
```
@echuraev `torch.concat` is preserved as `aten::concat` only when it is in a nested module like this code. (In most cases, it is converted to `aten::cat` instead of `aten::concat`.) I've tried to find a reason for this, but haven't found one.
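For context, here is a minimal sketch (my own, not from the PR; the `Inner`/`Outer` names are invented) of how one can inspect which ATen op a trace actually records:

```python
import torch
from torch.nn import Module

class Inner(Module):
    def forward(self, x):
        # May be recorded as aten::cat or aten::concat depending on nesting.
        return torch.concat([x, x], 1)

class Outer(Module):
    def __init__(self):
        super().__init__()
        self.mod = Inner()

    def forward(self, x):
        return self.mod(x)

traced = torch.jit.trace(Outer(), torch.rand(2, 3))
# inlined_graph expands submodule calls, so the recorded ATen op
# is visible in the printed graph.
print(traced.inlined_graph)
```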
Thank you for your reply! In that case, could you please explain it in the class docstring instead of using `# pylint: disable=missing-class-docstring`?
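For illustration, one possible docstring (my wording, not taken from the PR) might be:

```python
class Concatenate3(Module):
    """Tests the aten::concat operator.

    torch.concat is only traced as aten::concat (rather than aten::cat)
    when the call happens inside a nested submodule, hence the wrapper.
    """
```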
LGTM! Thank you for your PR!
I think it is quite a simple problem: `aten::concat` is just an alias of `aten::cat`, but it is not supported.
https://github.com/pytorch/pytorch/blob/3fbfa8cd0a5cefadb3f116c5cd0d60e96ab8c99e/aten/src/ATen/native/TensorShape.cpp#L667
If needed, I will add a minimal example to reproduce it.
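A minimal reproduction along those lines might look like this (a sketch under my own assumptions, reusing the nested-module trick discussed above; before this PR the conversion should fail on the unsupported `aten::concat` op):

```python
import torch
from tvm import relay

class Inner(torch.nn.Module):
    def forward(self, x):
        return torch.concat([x, x], 1)

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = Inner()

    def forward(self, x):
        return self.mod(x)

traced = torch.jit.trace(Outer(), torch.rand(2, 3))
# Before this PR, from_pytorch reported aten::concat as an unimplemented
# operator; with the alias registered, the conversion succeeds.
mod, params = relay.frontend.from_pytorch(traced, [("x", (2, 3))])
```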