From 1353305154f797917404eeff7c3536b6e98098b1 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Tue, 3 May 2022 10:06:52 -0700 Subject: [PATCH 01/15] Support gather of pytorch --- api/src/main/java/ai/djl/ndarray/NDArray.java | 11 ++++ .../ai/djl/ndarray/index/NDArrayIndexer.java | 8 +++ .../java/ai/djl/ndarray/index/NDIndex.java | 11 ++-- .../djl/ndarray/index/NDIndexFullGather.java | 52 +++++++++++++++++++ docs/get.md | 13 ++++- .../ai/djl/mxnet/engine/MxNDArrayIndexer.java | 7 +++ .../djl/pytorch/engine/PtNDArrayIndexer.java | 9 ++++ .../java/ai/djl/pytorch/jni/JniUtils.java | 10 +++- .../tensorflow/engine/TfNDArrayIndexer.java | 7 +++ .../tests/ndarray/NDIndexTest.java | 14 +++++ 10 files changed, 133 insertions(+), 9 deletions(-) create mode 100644 api/src/main/java/ai/djl/ndarray/index/NDIndexFullGather.java diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 88eec5517f1..d20ad8d9530 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -14,6 +14,7 @@ import ai.djl.Device; import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.index.NDIndexFullGather; import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.internal.NDFormat; import ai.djl.ndarray.types.DataType; @@ -511,6 +512,16 @@ default NDArray get(NDIndex index) { return getNDArrayInternal().getIndexer().get(this, index); } + /** + * Returns a partial {@code NDArray} pointed by the indexed array. + * + * @param gather: {index, axis} + * @return the partial {@code NDArray} + */ + default NDArray get(NDIndexFullGather gather) { + return getNDArrayInternal().getIndexer().get(this, gather); + } + /** * Returns a partial {@code NDArray}. * diff --git a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java index 31f4008120f..8afc77e76ce 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java @@ -77,6 +77,14 @@ public NDArray get(NDArray array, NDIndex index) { "get() currently supports all, fixed, and slices indices"); } + /** + * Gets the values of the array at the fullGather: (indices, axis). + * + * @param array the array to set + * @param gather the fullGather: (indices, axis) + */ + public abstract NDArray get(NDArray array, NDIndexFullGather gather); + /** * Sets the values of the array at the fullSlice with an array. * diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java index c271747bc99..09f877832e8 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java @@ -13,12 +13,7 @@ package ai.djl.ndarray.index; import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.index.dim.NDIndexAll; -import ai.djl.ndarray.index.dim.NDIndexBooleans; -import ai.djl.ndarray.index.dim.NDIndexElement; -import ai.djl.ndarray.index.dim.NDIndexFixed; -import ai.djl.ndarray.index.dim.NDIndexPick; -import ai.djl.ndarray.index.dim.NDIndexSlice; +import ai.djl.ndarray.index.dim.*; import ai.djl.ndarray.types.DataType; import java.util.ArrayList; import java.util.List; @@ -411,4 +406,8 @@ private Long parseSliceItem(String sliceItem, int argIndex, Object... args) { return Long.parseLong(sliceItem); } } + + public NDIndexFullGather gather(NDArray index, int axis) { + return new NDIndexFullGather(index, axis); + } } diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndexFullGather.java b/api/src/main/java/ai/djl/ndarray/index/NDIndexFullGather.java new file mode 100644 index 00000000000..9126a632087 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndexFullGather.java @@ -0,0 +1,52 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.index; + +import ai.djl.ndarray.NDArray; + + +/** A simplified representation of a gather-based {@link NDIndex}-like class. */ +public final class NDIndexFullGather { + + private final NDArray indices; + private final int axis; + + /** + * Constructs a new {@link NDIndexFullGather}. + * + * @param indices the indices to gather + * @param axis the axis to gather from + */ + NDIndexFullGather(NDArray indices, int axis) { + this.indices = indices; + this.axis = axis; + } + + /** + * Returns the indices to gather. + * + * @return the indices to gather + */ + public NDArray getIndices() { + return indices; + } + + /** + * Returns the axis to gather. + * + * @return the axis to gather + */ + public int getAxis() { + return axis; + } +} diff --git a/docs/get.md b/docs/get.md index 39330e2f063..da14ab34a65 100644 --- a/docs/get.md +++ b/docs/get.md @@ -90,9 +90,17 @@ If you build following the above instructions, you will use the version of the c You can look here to find the [list of DJL releases](https://github.com/deepjavalibrary/djl/releases). -### Using build from source in another project +### Using built-from-source version in another project -If you have another project and want to use a custom version of DJL in it, you can do this by building from source. Once you have built DJL from source, run `./gradlew publishToMavenLocal`. This will install DJL to your local maven repository cache located on your filesystem at `~/.m2/repository`. When you publish here, it will use the same snapshot versions as used by the snapshot repository above for adding the DJL dependencies. +If you have another project and want to use a custom version of DJL in it, you can do the following. First, build DJL from source by running `./gradlew build` inside djl folder. Then run `./gradlew publishToMavenLocal`, which will install DJL to your local maven repository cache, located on your filesystem at `~/.m2/repository`. After publishing it here, you can add the DJL snapshot version dependencies as shown below +```groovy +dependencies { + implementation platform("ai.djl:bom:0.17.0-SNAPSHOT") +} +``` +This snapshot version is the same as the custom DJL repository. + +You also need to change directory to `djl/bom`. Then build and publish it to maven local same as was done in `djl`. From there, you may have to update the Maven or Gradle build of the project importing DJL to also look at the local maven repository cache for your locally published versions of DJL. For Maven, no changes are necessary. If you are using Gradle, you will have to add the maven local repository such as this [example](https://github.com/deepjavalibrary/djl-demo/blob/135c969d66d98d1672852e53a37e52ca1da3e325/pneumonia-detection/build.gradle#L11): @@ -101,3 +109,4 @@ repositories { mavenLocal() } ``` +Note that `mavenCentral()` may still be needed for applications like log4j and json. \ No newline at end of file diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java index 668d70d3087..f96c2768f92 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java @@ -15,6 +15,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.index.NDArrayIndexer; +import ai.djl.ndarray.index.NDIndexFullGather; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; import ai.djl.ndarray.types.Shape; @@ -61,6 +62,12 @@ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { return result; } + /** {@inheritDoc} */ + @Override + public NDArray get(NDArray array, NDIndexFullGather gather) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java index bd9f465a5d8..55487ab556d 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java @@ -15,6 +15,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.index.NDArrayIndexer; import ai.djl.ndarray.index.dim.NDIndexBooleans; +import ai.djl.ndarray.index.NDIndexFullGather; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; import ai.djl.ndarray.types.Shape; @@ -37,6 +38,14 @@ public NDArray get(NDArray array, NDIndexFullPick fullPick) { manager.from(array), manager.from(fullPick.getIndices()), fullPick.getAxis()); } + /** {@inheritDoc} */ + @Override + public NDArray get(NDArray array, NDIndexFullGather gather) { + return JniUtils.gather( + manager.from(array), manager.from(gather.getIndices()), gather.getAxis() + ); + } + /** {@inheritDoc} */ @Override public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index a02bdc47427..ac176ea82f9 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -13,6 +13,7 @@ package ai.djl.pytorch.jni; import ai.djl.Device; +import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; @@ -342,7 +343,14 @@ public static void set(PtNDArray self, ByteBuffer data) { PyTorchLibrary.LIB.torchSet(self.getHandle(), data); } - public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) { + public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) { // dim was originally int. Why long here or inside torchGather? + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false) + ); + } + + public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) { // dim was originally int. Why long here? Shape indexShape = index.getShape(); Shape ndShape = ndArray.getShape(); int shapeDims = indexShape.dimension(); diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java index 4b2cdc744e7..48d6ea14fa3 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java @@ -14,6 +14,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.index.NDArrayIndexer; +import ai.djl.ndarray.index.NDIndexFullGather; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; @@ -32,6 +33,12 @@ public NDArray get(NDArray array, NDIndexFullPick fullPick) { throw new UnsupportedOperationException("Not implemented"); } + /** {@inheritDoc} */ + @Override + public NDArray get(NDArray array, NDIndexFullGather gather) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 5bf2e60106f..985481005e8 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -12,6 +12,7 @@ */ package ai.djl.integration.tests.ndarray; +import ai.djl.engine.Engine; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.index.NDIndex; @@ -52,6 +53,19 @@ public void testPick() { } } + @Test + public void testGather() { + Engine engine = Engine.getEngine("PyTorch"); + try (NDManager manager = engine.newBaseManager()) { + NDArray arr = manager.arange(20).reshape(-1, 4); + long [] idx = {0, 0, 2, 1, 1, 2}; + NDArray sel = manager.create(idx, new Shape(3, 2)); + NDArray actual = arr.get(new NDIndex().gather(sel, 1)); + NDArray expected = manager.create(new int[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2)); + Assert.assertEquals(actual, expected); + } + } + @Test public void testGet() { try (NDManager manager = NDManager.newBaseManager()) { From f5cba83db6b2a848cf98efb7479c950f6e1a84a9 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Tue, 3 May 2022 17:21:50 -0700 Subject: [PATCH 02/15] bugs in javadoc related --- api/src/main/java/ai/djl/ndarray/NDArray.java | 6 +++++- .../java/ai/djl/ndarray/index/NDArrayIndexer.java | 7 ++++--- api/src/main/java/ai/djl/ndarray/index/NDIndex.java | 5 +++++ docs/README.md | 1 + .../djl/integration/tests/ndarray/NDIndexTest.java | 12 +++++++++--- 5 files changed, 24 insertions(+), 7 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index d20ad8d9530..3db18eda385 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -514,8 +514,12 @@ default NDArray get(NDIndex index) { /** * Returns a partial {@code NDArray} pointed by the indexed array. + * Given NDArray arr, NDArray idx, and long axis, the output is + * out_{ijk} = arr_{idx_{ijk}, j, k} if axis=0 + * or arr_{i, idx_{ijk}, k} if axis=1 + * or arr_{i, j, idx_{ijk}} if axis=2 * - * @param gather: {index, axis} + * @param gather includes {index, axis} to specify the argument to gather * @return the partial {@code NDArray} */ default NDArray get(NDIndexFullGather gather) { diff --git a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java index 8afc77e76ce..27471ed9542 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java @@ -78,10 +78,11 @@ public NDArray get(NDArray array, NDIndex index) { } /** - * Gets the values of the array at the fullGather: (indices, axis). + * Gets the values of the array according to NDIndexFullGather. * - * @param array the array to set - * @param gather the fullGather: (indices, axis) + * @param array the array to get from + * @param gather an NDIndexFullGather {NDArray indices, long axis} + * @return the subarray */ public abstract NDArray get(NDArray array, NDIndexFullGather gather); diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java index 09f877832e8..156400c74f4 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java @@ -407,6 +407,11 @@ private Long parseSliceItem(String sliceItem, int argIndex, Object... args) { } } + /** + * Returns the argument for gather + * + * @return NDIndexFullGather which includes {index, axis} + */ public NDIndexFullGather gather(NDArray index, int axis) { return new NDIndexFullGather(index, axis); } diff --git a/docs/README.md b/docs/README.md index 3f46ce5020c..ed81703bc4e 100644 --- a/docs/README.md +++ b/docs/README.md @@ -3,6 +3,7 @@ This folder contains examples and documentation for the Deep Java Library (DJL) project. ## [JavaDoc API Reference](https://javadoc.djl.ai/) +Note: when searching in JavaDoc, if your access is denied, please try removing the string `undefined` in the url. ## [Demos](https://djl.ai/website/demo.html) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 985481005e8..893a493768f 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -58,11 +58,17 @@ public void testGather() { Engine engine = Engine.getEngine("PyTorch"); try (NDManager manager = engine.newBaseManager()) { NDArray arr = manager.arange(20).reshape(-1, 4); - long [] idx = {0, 0, 2, 1, 1, 2}; + long[] idx = {0, 0, 2, 1, 1, 2}; NDArray sel = manager.create(idx, new Shape(3, 2)); NDArray actual = arr.get(new NDIndex().gather(sel, 1)); - NDArray expected = manager.create(new int[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2)); - Assert.assertEquals(actual, expected); + NDArray expected = manager.create(new int[]{0, 0, 6, 5, 9, 10}, new Shape(3, 2)); + try { + Assert.assertEquals(actual, expected); + } catch (Exception e) { + System.out.println(e); + System.out.println(actual); + System.out.println(expected); + } } } From 6caba3e5c488813ce319914a7f626aa9a2b8614d Mon Sep 17 00:00:00 2001 From: Kexin Date: Tue, 3 May 2022 19:24:48 -0700 Subject: [PATCH 03/15] bugs in javadoc related --- api/src/main/java/ai/djl/ndarray/index/NDIndex.java | 5 ++++- .../ai/djl/integration/tests/ndarray/NDIndexTest.java | 10 +++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java index 156400c74f4..11fe3a02e27 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java @@ -408,8 +408,11 @@ private Long parseSliceItem(String sliceItem, int argIndex, Object... args) { } /** - * Returns the argument for gather + * Returns the argument for gather. * + * @param index the indices should be NDArray. Each entry is an index of the axis specified + * next. The shape of the returned NDArray is according to the shape of index. + * @param axis specifies the axis of the index entry. * @return NDIndexFullGather which includes {index, axis} */ public NDIndexFullGather gather(NDArray index, int axis) { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 893a493768f..c32817bdcde 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -61,13 +61,13 @@ public void testGather() { long[] idx = {0, 0, 2, 1, 1, 2}; NDArray sel = manager.create(idx, new Shape(3, 2)); NDArray actual = arr.get(new NDIndex().gather(sel, 1)); - NDArray expected = manager.create(new int[]{0, 0, 6, 5, 9, 10}, new Shape(3, 2)); + NDArray expected = manager.create(new float[]{0f, 0, 6, 5, 9, 10}, new Shape(3, 2)); try { Assert.assertEquals(actual, expected); - } catch (Exception e) { - System.out.println(e); - System.out.println(actual); - System.out.println(expected); + } + catch (AssertionError e) { + System.out.println(e.getMessage()); + throw new AssertionError(e); } } } From cf61fa9e4d4f0916eba7ef414bedc08df1965f82 Mon Sep 17 00:00:00 2001 From: Kexin Date: Tue, 3 May 2022 19:28:24 -0700 Subject: [PATCH 04/15] bugs in javadoc related --- .../main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index c32817bdcde..c61c050df68 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -61,7 +61,7 @@ public void testGather() { long[] idx = {0, 0, 2, 1, 1, 2}; NDArray sel = manager.create(idx, new Shape(3, 2)); NDArray actual = arr.get(new NDIndex().gather(sel, 1)); - NDArray expected = manager.create(new float[]{0f, 0, 6, 5, 9, 10}, new Shape(3, 2)); + NDArray expected = manager.create(new int[]{0, 0, 6, 5, 9, 10}, new Shape(3, 2)); try { Assert.assertEquals(actual, expected); } From 9fa61c1d7ab25940079d13e3e07af05fb32f58ea Mon Sep 17 00:00:00 2001 From: Kexin Date: Tue, 3 May 2022 19:54:49 -0700 Subject: [PATCH 05/15] import .* is forbidden --- api/src/main/java/ai/djl/ndarray/index/NDIndex.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java index 11fe3a02e27..1434f580437 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java @@ -13,7 +13,12 @@ package ai.djl.ndarray.index; import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.index.dim.*; +import ai.djl.ndarray.index.dim.NDIndexBooleans; +import ai.djl.ndarray.index.dim.NDIndexFixed; +import ai.djl.ndarray.index.dim.NDIndexAll; +import ai.djl.ndarray.index.dim.NDIndexElement; +import ai.djl.ndarray.index.dim.NDIndexPick; +import ai.djl.ndarray.index.dim.NDIndexSlice; import ai.djl.ndarray.types.DataType; import java.util.ArrayList; import java.util.List; From 8b6526bb91f3a58ec468f126105ec21aa4fd8fb6 Mon Sep 17 00:00:00 2001 From: Kexin Date: Tue, 3 May 2022 19:59:25 -0700 Subject: [PATCH 06/15] windows assertion error on int array --- .../djl/integration/tests/ndarray/NDIndexTest.java | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index c61c050df68..ea15d1ad5ac 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -57,18 +57,12 @@ public void testPick() { public void testGather() { Engine engine = Engine.getEngine("PyTorch"); try (NDManager manager = engine.newBaseManager()) { - NDArray arr = manager.arange(20).reshape(-1, 4); + NDArray arr = manager.arange(20f).reshape(-1, 4); long[] idx = {0, 0, 2, 1, 1, 2}; NDArray sel = manager.create(idx, new Shape(3, 2)); NDArray actual = arr.get(new NDIndex().gather(sel, 1)); - NDArray expected = manager.create(new int[]{0, 0, 6, 5, 9, 10}, new Shape(3, 2)); - try { - Assert.assertEquals(actual, expected); - } - catch (AssertionError e) { - System.out.println(e.getMessage()); - throw new AssertionError(e); - } + NDArray expected = manager.create(new float[]{0, 0, 6, 5, 9, 10}, new Shape(3, 2)); + Assert.assertEquals(actual, expected); } } From 6a2d5db211dba676ca615688adf534ad2237d426 Mon Sep 17 00:00:00 2001 From: Kexin Date: Tue, 3 May 2022 22:39:49 -0700 Subject: [PATCH 07/15] import order is wrong --- api/src/main/java/ai/djl/ndarray/index/NDIndex.java | 4 ++-- .../djl/integration/tests/ndarray/NDIndexTest.java | 13 ------------- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java index 1434f580437..b4ad51f8fba 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java @@ -13,10 +13,10 @@ package ai.djl.ndarray.index; import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.index.dim.NDIndexBooleans; -import ai.djl.ndarray.index.dim.NDIndexFixed; import ai.djl.ndarray.index.dim.NDIndexAll; +import ai.djl.ndarray.index.dim.NDIndexBooleans; import ai.djl.ndarray.index.dim.NDIndexElement; +import ai.djl.ndarray.index.dim.NDIndexFixed; import ai.djl.ndarray.index.dim.NDIndexPick; import ai.djl.ndarray.index.dim.NDIndexSlice; import ai.djl.ndarray.types.DataType; diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index ea15d1ad5ac..f5432642c72 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -53,19 +53,6 @@ public void testPick() { } } - @Test - public void testGather() { - Engine engine = Engine.getEngine("PyTorch"); - try (NDManager manager = engine.newBaseManager()) { - NDArray arr = manager.arange(20f).reshape(-1, 4); - long[] idx = {0, 0, 2, 1, 1, 2}; - NDArray sel = manager.create(idx, new Shape(3, 2)); - NDArray actual = arr.get(new NDIndex().gather(sel, 1)); - NDArray expected = manager.create(new float[]{0, 0, 6, 5, 9, 10}, new Shape(3, 2)); - Assert.assertEquals(actual, expected); - } - } - @Test public void testGet() { try (NDManager manager = NDManager.newBaseManager()) { From 306a25eabf4ec8ac60056e82a7c3a9cba5e0bd9b Mon Sep 17 00:00:00 2001 From: Kexin Date: Tue, 3 May 2022 23:26:03 -0700 Subject: [PATCH 08/15] api:formatJava --- api/src/main/java/ai/djl/ndarray/NDArray.java | 8 +++----- api/src/main/java/ai/djl/ndarray/index/NDIndex.java | 2 +- .../main/java/ai/djl/ndarray/index/NDIndexFullGather.java | 5 ++--- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 3db18eda385..2eb866245d7 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -513,11 +513,9 @@ default NDArray get(NDIndex index) { } /** - * Returns a partial {@code NDArray} pointed by the indexed array. - * Given NDArray arr, NDArray idx, and long axis, the output is - * out_{ijk} = arr_{idx_{ijk}, j, k} if axis=0 - * or arr_{i, idx_{ijk}, k} if axis=1 - * or arr_{i, j, idx_{ijk}} if axis=2 + * Returns a partial {@code NDArray} pointed by the indexed array. Given NDArray arr, NDArray + * idx, and long axis, the output is out_{ijk} = arr_{idx_{ijk}, j, k} if axis=0 or arr_{i, + * idx_{ijk}, k} if axis=1 or arr_{i, j, idx_{ijk}} if axis=2 * * @param gather includes {index, axis} to specify the argument to gather * @return the partial {@code NDArray} diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java index b4ad51f8fba..811b4a440f6 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java @@ -416,7 +416,7 @@ private Long parseSliceItem(String sliceItem, int argIndex, Object... args) { * Returns the argument for gather. * * @param index the indices should be NDArray. Each entry is an index of the axis specified - * next. The shape of the returned NDArray is according to the shape of index. + * next. The shape of the returned NDArray is according to the shape of index. * @param axis specifies the axis of the index entry. * @return NDIndexFullGather which includes {index, axis} */ diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndexFullGather.java b/api/src/main/java/ai/djl/ndarray/index/NDIndexFullGather.java index 9126a632087..95fef336020 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndexFullGather.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndexFullGather.java @@ -14,12 +14,11 @@ import ai.djl.ndarray.NDArray; - /** A simplified representation of a gather-based {@link NDIndex}-like class. */ public final class NDIndexFullGather { - private final NDArray indices; - private final int axis; + private NDArray indices; + private int axis; /** * Constructs a new {@link NDIndexFullGather}. From e341d85ac8d4ce15e4b54e1b6820da4d954d1f0f Mon Sep 17 00:00:00 2001 From: Kexin Date: Wed, 4 May 2022 00:30:24 -0700 Subject: [PATCH 09/15] api:formatJava --- .../main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java | 5 ++--- .../src/main/java/ai/djl/pytorch/jni/JniUtils.java | 8 +++----- .../ai/djl/integration/tests/ndarray/NDIndexTest.java | 1 - 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java index 55487ab556d..3424a9c8c3c 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java @@ -14,8 +14,8 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.index.NDArrayIndexer; -import ai.djl.ndarray.index.dim.NDIndexBooleans; import ai.djl.ndarray.index.NDIndexFullGather; +import ai.djl.ndarray.index.dim.NDIndexBooleans; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; import ai.djl.ndarray.types.Shape; @@ -42,8 +42,7 @@ public NDArray get(NDArray array, NDIndexFullPick fullPick) { @Override public NDArray get(NDArray array, NDIndexFullGather gather) { return JniUtils.gather( - manager.from(array), manager.from(gather.getIndices()), gather.getAxis() - ); + manager.from(array), manager.from(gather.getIndices()), gather.getAxis()); } /** {@inheritDoc} */ diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index ac176ea82f9..170b2c35bbd 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -13,7 +13,6 @@ package ai.djl.pytorch.jni; import ai.djl.Device; -import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; @@ -343,14 +342,13 @@ public static void set(PtNDArray self, ByteBuffer data) { PyTorchLibrary.LIB.torchSet(self.getHandle(), data); } - public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) { // dim was originally int. Why long here or inside torchGather? + public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) { return new PtNDArray( ndArray.getManager(), - PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false) - ); + PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } - public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) { // dim was originally int. Why long here? + public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) { Shape indexShape = index.getShape(); Shape ndShape = ndArray.getShape(); int shapeDims = indexShape.dimension(); diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index f5432642c72..5bf2e60106f 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -12,7 +12,6 @@ */ package ai.djl.integration.tests.ndarray; -import ai.djl.engine.Engine; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.index.NDIndex; From d1fc3a9391cee4f64ab9fdea870a11544741dd43 Mon Sep 17 00:00:00 2001 From: Kexin Date: Wed, 4 May 2022 10:42:32 -0700 Subject: [PATCH 10/15] Add testGather but w/o specifying pytorch engine --- .../djl/integration/tests/ndarray/NDIndexTest.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 5bf2e60106f..9c48bc1699a 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -52,6 +52,18 @@ public void testPick() { } } + @Test + public void testGather() { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray arr = manager.arange(20f).reshape(-1, 4); + long[] idx = {0, 0, 2, 1, 1, 2}; + NDArray sel = manager.create(idx, new Shape(3, 2)); + NDArray actual = arr.get(new NDIndex().gather(sel, 1)); + NDArray expected = manager.create(new float[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2)); + Assert.assertEquals(actual, expected); + } + } + @Test public void testGet() { try (NDManager manager = NDManager.newBaseManager()) { From 7bd40e52d981216fa5c175b1c862e816c1894cad Mon Sep 17 00:00:00 2001 From: Kexin Date: Thu, 5 May 2022 18:57:54 -0700 Subject: [PATCH 11/15] update gather, update gradle version, update two docs --- api/src/main/java/ai/djl/ndarray/NDArray.java | 24 ++++----- .../java/ai/djl/ndarray/NDArrayAdapter.java | 6 +++ .../ai/djl/ndarray/index/NDArrayIndexer.java | 9 ---- .../java/ai/djl/ndarray/index/NDIndex.java | 12 ----- .../djl/ndarray/index/NDIndexFullGather.java | 51 ------------------- docs/README.md | 1 + docs/get.md | 3 ++ .../java/ai/djl/mxnet/engine/MxNDArray.java | 6 +++ .../ai/djl/mxnet/engine/MxNDArrayIndexer.java | 7 --- .../java/ai/djl/pytorch/engine/PtNDArray.java | 9 ++++ .../djl/pytorch/engine/PtNDArrayIndexer.java | 8 --- .../ai/djl/tensorflow/engine/TfNDArray.java | 6 +++ .../tensorflow/engine/TfNDArrayIndexer.java | 7 --- gradle/wrapper/gradle-wrapper.properties | 2 +- .../tests/ndarray/NDIndexTest.java | 8 +-- 15 files changed, 48 insertions(+), 111 deletions(-) delete mode 100644 api/src/main/java/ai/djl/ndarray/index/NDIndexFullGather.java diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 2eb866245d7..c35c7f09545 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -14,7 +14,6 @@ import ai.djl.Device; import ai.djl.ndarray.index.NDIndex; -import ai.djl.ndarray.index.NDIndexFullGather; import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.internal.NDFormat; import ai.djl.ndarray.types.DataType; @@ -512,18 +511,6 @@ default NDArray get(NDIndex index) { return getNDArrayInternal().getIndexer().get(this, index); } - /** - * Returns a partial {@code NDArray} pointed by the indexed array. Given NDArray arr, NDArray - * idx, and long axis, the output is out_{ijk} = arr_{idx_{ijk}, j, k} if axis=0 or arr_{i, - * idx_{ijk}, k} if axis=1 or arr_{i, j, idx_{ijk}} if axis=2 - * - * @param gather includes {index, axis} to specify the argument to gather - * @return the partial {@code NDArray} - */ - default NDArray get(NDIndexFullGather gather) { - return getNDArrayInternal().getIndexer().get(this, gather); - } - /** * Returns a partial {@code NDArray}. * @@ -558,6 +545,17 @@ default NDArray get(NDArray index) { return get(new NDIndex().addBooleanIndex(index)); } + /** + * Returns a partial {@code NDArray} pointed by the indexed array. Given NDArray arr, NDArray + * idx, and long axis, the output is out_{ijk} = arr_{idx_{ijk}, j, k} if axis=0 or arr_{i, + * idx_{ijk}, k} if axis=1 or arr_{i, j, idx_{ijk}} if axis=2 + * + * @param index picks the elements of an NDArray to the same position as index + * @param axis the entries of index are indices of axis + * @return the partial {@code NDArray} of the same shape as index + */ + NDArray gather(NDArray index, int axis); + /** * Returns a scalar {@code NDArray} corresponding to a single element. * diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 1a30edb19cf..46f2634c351 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -176,6 +176,12 @@ public String[] toStringArray(Charset charset) { throw new UnsupportedOperationException(UNSUPPORTED_MSG); } + /** {@inheritDoc} */ + @Override + public NDArray gather(NDArray index, int axis) { + throw new UnsupportedOperationException(UNSUPPORTED_MSG); + } + /** {@inheritDoc} */ @Override public void set(Buffer data) { diff --git a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java index 27471ed9542..31f4008120f 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java @@ -77,15 +77,6 @@ public NDArray get(NDArray array, NDIndex index) { "get() currently supports all, fixed, and slices indices"); } - /** - * Gets the values of the array according to NDIndexFullGather. - * - * @param array the array to get from - * @param gather an NDIndexFullGather {NDArray indices, long axis} - * @return the subarray - */ - public abstract NDArray get(NDArray array, NDIndexFullGather gather); - /** * Sets the values of the array at the fullSlice with an array. * diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java index 811b4a440f6..c271747bc99 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java @@ -411,16 +411,4 @@ private Long parseSliceItem(String sliceItem, int argIndex, Object... args) { return Long.parseLong(sliceItem); } } - - /** - * Returns the argument for gather. - * - * @param index the indices should be NDArray. Each entry is an index of the axis specified - * next. The shape of the returned NDArray is according to the shape of index. - * @param axis specifies the axis of the index entry. - * @return NDIndexFullGather which includes {index, axis} - */ - public NDIndexFullGather gather(NDArray index, int axis) { - return new NDIndexFullGather(index, axis); - } } diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndexFullGather.java b/api/src/main/java/ai/djl/ndarray/index/NDIndexFullGather.java deleted file mode 100644 index 95fef336020..00000000000 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndexFullGather.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.ndarray.index; - -import ai.djl.ndarray.NDArray; - -/** A simplified representation of a gather-based {@link NDIndex}-like class. */ -public final class NDIndexFullGather { - - private NDArray indices; - private int axis; - - /** - * Constructs a new {@link NDIndexFullGather}. - * - * @param indices the indices to gather - * @param axis the axis to gather from - */ - NDIndexFullGather(NDArray indices, int axis) { - this.indices = indices; - this.axis = axis; - } - - /** - * Returns the indices to gather. - * - * @return the indices to gather - */ - public NDArray getIndices() { - return indices; - } - - /** - * Returns the axis to gather. - * - * @return the axis to gather - */ - public int getAxis() { - return axis; - } -} diff --git a/docs/README.md b/docs/README.md index ed81703bc4e..22b7815f680 100644 --- a/docs/README.md +++ b/docs/README.md @@ -3,6 +3,7 @@ This folder contains examples and documentation for the Deep Java Library (DJL) project. ## [JavaDoc API Reference](https://javadoc.djl.ai/) + Note: when searching in JavaDoc, if your access is denied, please try removing the string `undefined` in the url. ## [Demos](https://djl.ai/website/demo.html) diff --git a/docs/get.md b/docs/get.md index da14ab34a65..056c194455d 100644 --- a/docs/get.md +++ b/docs/get.md @@ -93,11 +93,13 @@ You can look here to find the [list of DJL releases](https://github.com/deepjava ### Using built-from-source version in another project If you have another project and want to use a custom version of DJL in it, you can do the following. First, build DJL from source by running `./gradlew build` inside djl folder. Then run `./gradlew publishToMavenLocal`, which will install DJL to your local maven repository cache, located on your filesystem at `~/.m2/repository`. After publishing it here, you can add the DJL snapshot version dependencies as shown below + ```groovy dependencies { implementation platform("ai.djl:bom:0.17.0-SNAPSHOT") } ``` + This snapshot version is the same as the custom DJL repository. You also need to change directory to `djl/bom`. Then build and publish it to maven local same as was done in `djl`. @@ -109,4 +111,5 @@ repositories { mavenLocal() } ``` + Note that `mavenCentral()` may still be needed for applications like log4j and json. \ No newline at end of file diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index 68ab1b61b6f..f786b022930 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -306,6 +306,12 @@ public void set(Buffer data) { JnaUtils.syncCopyFromCPU(getHandle(), buf, size); } + /** {@inheritDoc} */ + @Override + public NDArray gather(NDArray index, int axis) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public void copyTo(NDArray ndArray) { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java index f96c2768f92..668d70d3087 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java @@ -15,7 +15,6 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.index.NDArrayIndexer; -import ai.djl.ndarray.index.NDIndexFullGather; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; import ai.djl.ndarray.types.Shape; @@ -62,12 +61,6 @@ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { return result; } - /** {@inheritDoc} */ - @Override - public NDArray get(NDArray array, NDIndexFullGather gather) { - throw new UnsupportedOperationException("Not implemented yet."); - } - /** {@inheritDoc} */ @Override public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index ebad116054b..a1cacb83e93 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -239,6 +239,15 @@ public NDArray get(long... indices) { return JniUtils.getItem(this, indices); } + /** {@inheritDoc} */ + @Override + public NDArray gather(NDArray index, int axis) { + if (!(index instanceof PtNDArray)) { + throw new IllegalArgumentException("Only PtNDArray is supported."); + } + return JniUtils.gather(this, (PtNDArray) index, axis); + } + /** {@inheritDoc} */ @Override public void copyTo(NDArray array) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java index 3424a9c8c3c..bd9f465a5d8 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java @@ -14,7 +14,6 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.index.NDArrayIndexer; -import ai.djl.ndarray.index.NDIndexFullGather; import ai.djl.ndarray.index.dim.NDIndexBooleans; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; @@ -38,13 +37,6 @@ public NDArray get(NDArray array, NDIndexFullPick fullPick) { manager.from(array), manager.from(fullPick.getIndices()), fullPick.getAxis()); } - /** {@inheritDoc} */ - @Override - public NDArray get(NDArray array, NDIndexFullGather gather) { - return JniUtils.gather( - manager.from(array), manager.from(gather.getIndices()), gather.getAxis()); - } - /** {@inheritDoc} */ @Override public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 3d5c8f09c9d..026680b97dd 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -206,6 +206,12 @@ public void set(Buffer data) { JavacppUtils.setByteBuffer(getHandle(), buf); } + /** {@inheritDoc} */ + @Override + public NDArray gather(NDArray index, int axis) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public void attach(NDManager manager) { diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java index 48d6ea14fa3..4b2cdc744e7 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java @@ -14,7 +14,6 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.index.NDArrayIndexer; -import ai.djl.ndarray.index.NDIndexFullGather; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; @@ -33,12 +32,6 @@ public NDArray get(NDArray array, NDIndexFullPick fullPick) { throw new UnsupportedOperationException("Not implemented"); } - /** {@inheritDoc} */ - @Override - public NDArray get(NDArray array, NDIndexFullGather gather) { - throw new UnsupportedOperationException("Not implemented yet."); - } - /** {@inheritDoc} */ @Override public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index f1ffe6612c1..fdd2345a06f 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.4.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.4.2-bin.zip diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 9c48bc1699a..fdb7a6d833a 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -12,6 +12,7 @@ */ package ai.djl.integration.tests.ndarray; +import ai.djl.engine.Engine; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.index.NDIndex; @@ -54,11 +55,12 @@ public void testPick() { @Test public void testGather() { - try (NDManager manager = NDManager.newBaseManager()) { + Engine engine = Engine.getEngine("PyTorch"); + try (NDManager manager = engine.newBaseManager()) { NDArray arr = manager.arange(20f).reshape(-1, 4); long[] idx = {0, 0, 2, 1, 1, 2}; - NDArray sel = manager.create(idx, new Shape(3, 2)); - NDArray actual = arr.get(new NDIndex().gather(sel, 1)); + NDArray index = manager.create(idx, new Shape(3, 2)); + NDArray actual = arr.gather(index, 1); NDArray expected = manager.create(new float[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2)); Assert.assertEquals(actual, expected); } From 98a1de437ccb8b6119014af41691f37ec3d79334 Mon Sep 17 00:00:00 2001 From: Kexin Date: Thu, 5 May 2022 21:58:15 -0700 Subject: [PATCH 12/15] change testGather to MXNet --- .../main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index fdb7a6d833a..642bad98183 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -55,7 +55,7 @@ public void testPick() { @Test public void testGather() { - Engine engine = Engine.getEngine("PyTorch"); + Engine engine = Engine.getEngine("MXNet"); try (NDManager manager = engine.newBaseManager()) { NDArray arr = manager.arange(20f).reshape(-1, 4); long[] idx = {0, 0, 2, 1, 1, 2}; From 822c27b50432c0cdf89d46eb99b123323fe503f9 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Fri, 6 May 2022 09:01:40 -0700 Subject: [PATCH 13/15] testRequirement.notWindows --- .../java/ai/djl/integration/tests/ndarray/NDIndexTest.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 642bad98183..748f382e6c6 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.index.NDIndex; import ai.djl.ndarray.types.Shape; +import ai.djl.testing.TestRequirements; import org.testng.Assert; import org.testng.annotations.Test; @@ -55,7 +56,11 @@ public void testPick() { @Test public void testGather() { - Engine engine = Engine.getEngine("MXNet"); + // Currently in windows gradle cannot find all the engines. + // In the dependency, changing runtimeOnly to api however will remedy the problem. + // TODO: remove this when gradle problem is fixed. + TestRequirements.notWindows(); + Engine engine = Engine.getEngine("PyTorch"); try (NDManager manager = engine.newBaseManager()) { NDArray arr = manager.arange(20f).reshape(-1, 4); long[] idx = {0, 0, 2, 1, 1, 2}; From 8cb4ba28286a7d3c99fb71d33c5fe8c5efb4d9ef Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Fri, 6 May 2022 11:01:59 -0700 Subject: [PATCH 14/15] testGather --- .../java/ai/djl/integration/tests/ndarray/NDIndexTest.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 748f382e6c6..3e823515251 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -57,11 +57,10 @@ public void testPick() { @Test public void testGather() { // Currently in windows gradle cannot find all the engines. - // In the dependency, changing runtimeOnly to api however will remedy the problem. + // In the dependencies, changing runtimeOnly to api however will remedy the problem. // TODO: remove this when gradle problem is fixed. TestRequirements.notWindows(); - Engine engine = Engine.getEngine("PyTorch"); - try (NDManager manager = engine.newBaseManager()) { + try (NDManager manager = NDManager.newBaseManager()) { NDArray arr = manager.arange(20f).reshape(-1, 4); long[] idx = {0, 0, 2, 1, 1, 2}; NDArray index = manager.create(idx, new Shape(3, 2)); From ac5a706bf079c6d2b184998a75afa7779eb8bd73 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Fri, 6 May 2022 11:29:17 -0700 Subject: [PATCH 15/15] testGather format --- .../main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java | 1 - 1 file changed, 1 deletion(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 3e823515251..a954f3e0fd0 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -12,7 +12,6 @@ */ package ai.djl.integration.tests.ndarray; -import ai.djl.engine.Engine; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.index.NDIndex;