-
Notifications
You must be signed in to change notification settings - Fork 655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support gather of pytorch #1622
Changes from 14 commits
1353305
f5cba83
6caba3e
cf61fa9
9fa61c1
8b6526b
6a2d5db
306a25e
e341d85
d1fc3a9
7bd40e5
b887bc1
98a1de4
822c27b
8cb4ba2
ac5a706
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,10 +12,12 @@ | |
*/ | ||
package ai.djl.integration.tests.ndarray; | ||
|
||
import ai.djl.engine.Engine; | ||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDManager; | ||
import ai.djl.ndarray.index.NDIndex; | ||
import ai.djl.ndarray.types.Shape; | ||
import ai.djl.testing.TestRequirements; | ||
import org.testng.Assert; | ||
import org.testng.annotations.Test; | ||
|
||
|
@@ -52,6 +54,23 @@ public void testPick() { | |
} | ||
} | ||
|
||
@Test | ||
public void testGather() { | ||
// Currently in windows gradle cannot find all the engines. | ||
// In the dependency, changing runtimeOnly to api however will remedy the problem. | ||
// TODO: remove this when gradle problem is fixed. | ||
TestRequirements.notWindows(); | ||
Engine engine = Engine.getEngine("PyTorch"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The integration module is supposed to be for "engine agnostic tests". So, you shouldn't need to specify "PyTorch" here. For the other engines, it will throw a In the future, we should be able to implement gather for the other engines and then it will just run the test because it is already there. |
||
try (NDManager manager = engine.newBaseManager()) { | ||
NDArray arr = manager.arange(20f).reshape(-1, 4); | ||
long[] idx = {0, 0, 2, 1, 1, 2}; | ||
NDArray index = manager.create(idx, new Shape(3, 2)); | ||
NDArray actual = arr.gather(index, 1); | ||
NDArray expected = manager.create(new float[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2)); | ||
Assert.assertEquals(actual, expected); | ||
} | ||
} | ||
|
||
@Test | ||
public void testGet() { | ||
try (NDManager manager = NDManager.newBaseManager()) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this comment still valid?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably. It seems like something is wrong with
runtimeOnly
dependencies on windows gradle.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes after updating gradle, on GitHub the test stills fails for the same reason