-
Notifications
You must be signed in to change notification settings - Fork 655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support gather of pytorch #1622
Conversation
Codecov Report
@@ Coverage Diff @@
## master #1622 +/- ##
============================================
- Coverage 72.08% 70.87% -1.21%
- Complexity 5126 5433 +307
============================================
Files 473 507 +34
Lines 21970 23767 +1797
Branches 2351 2588 +237
============================================
+ Hits 15838 16846 +1008
- Misses 4925 5631 +706
- Partials 1207 1290 +83
Continue to review full report at Codecov.
|
@@ -0,0 +1,51 @@ | |||
/* | |||
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | |
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. Thanks!
private int axis; | ||
|
||
/** | ||
* Constructs a new {@link NDIndexFullGather}. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* Constructs a new {@link NDIndexFullGather}. | |
* Constructs a new {@code NDIndexFullGather} instance. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. Thanks!
e25a16f
to
d1fc3a9
Compare
// In the dependency, changing runtimeOnly to api however will remedy the problem. | ||
// TODO: remove this when gradle problem is fixed. | ||
TestRequirements.notWindows(); | ||
Engine engine = Engine.getEngine("PyTorch"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The integration module is supposed to be for "engine agnostic tests". So, you shouldn't need to specify "PyTorch" here. For the other engines, it will throw a UnsupportedOperationException
instead of running successfully. But, our test runner is designed to treat UnsupportedOperationException
as a special test result, not a failure. It's similar to how a testng SkipException
counts as a skipped test result rather than a test failure.
In the future, we should be able to implement gather for the other engines and then it will just run the test because it is already there.
@@ -52,6 +53,22 @@ public void testPick() { | |||
} | |||
} | |||
|
|||
@Test | |||
public void testGather() { | |||
// Currently in windows gradle cannot find all the engines. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this comment still valid?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably. It seems like something is wrong with runtimeOnly
dependencies on windows gradle.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes after updating gradle, on GitHub the test stills fails for the same reason
Description
This originiates from issue #248. This is regarding support of advanced indexing.
As part of this support, this PR implements
gather
. It's the same as the gather inpytorh and numpy :
But
gather
is used to index across axis, including the preceding axis. This conflicts with the linear storage ofList<NDIndexElement>
insideNDIndex
. Sogather
is treated parallel toNDIndex
when fed intoget
.It currently supports pytorch engine only. MXNet and TensorFlow have different defination of gather.
Similar functions in numpy / pytorch / mxnet can also be implemented: e.g.
torch.take
, and advanced indexing likearr[:, [0 ,2]]
andarr[idx_r, idx_c]