Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support gather of pytorch #1622

Merged
merged 16 commits into from
May 6, 2022
11 changes: 11 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,17 @@ default NDArray get(NDArray index) {
return get(new NDIndex().addBooleanIndex(index));
}

/**
* Returns a partial {@code NDArray} pointed by the indexed array. Given NDArray arr, NDArray
* idx, and long axis, the output is out_{ijk} = arr_{idx_{ijk}, j, k} if axis=0 or arr_{i,
* idx_{ijk}, k} if axis=1 or arr_{i, j, idx_{ijk}} if axis=2
*
* @param index picks the elements of an NDArray to the same position as index
* @param axis the entries of index are indices of axis
* @return the partial {@code NDArray} of the same shape as index
*/
NDArray gather(NDArray index, int axis);

/**
* Returns a scalar {@code NDArray} corresponding to a single element.
*
Expand Down
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ public String[] toStringArray(Charset charset) {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
public NDArray gather(NDArray index, int axis) {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
public void set(Buffer data) {
Expand Down
2 changes: 2 additions & 0 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ This folder contains examples and documentation for the Deep Java Library (DJL)

## [JavaDoc API Reference](https://javadoc.djl.ai/)

Note: when searching in JavaDoc, if your access is denied, please try removing the string `undefined` in the url.

## [Demos](https://djl.ai/website/demo.html)

## Cheat sheet
Expand Down
16 changes: 14 additions & 2 deletions docs/get.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,19 @@ If you build following the above instructions, you will use the version of the c

You can look here to find the [list of DJL releases](https://github.com/deepjavalibrary/djl/releases).

### Using build from source in another project
### Using built-from-source version in another project

If you have another project and want to use a custom version of DJL in it, you can do this by building from source. Once you have built DJL from source, run `./gradlew publishToMavenLocal`. This will install DJL to your local maven repository cache located on your filesystem at `~/.m2/repository`. When you publish here, it will use the same snapshot versions as used by the snapshot repository above for adding the DJL dependencies.
If you have another project and want to use a custom version of DJL in it, you can do the following. First, build DJL from source by running `./gradlew build` inside djl folder. Then run `./gradlew publishToMavenLocal`, which will install DJL to your local maven repository cache, located on your filesystem at `~/.m2/repository`. After publishing it here, you can add the DJL snapshot version dependencies as shown below

```groovy
dependencies {
implementation platform("ai.djl:bom:0.17.0-SNAPSHOT")
}
```

This snapshot version is the same as the custom DJL repository.

You also need to change directory to `djl/bom`. Then build and publish it to maven local same as was done in `djl`.

From there, you may have to update the Maven or Gradle build of the project importing DJL to also look at the local maven repository cache for your locally published versions of DJL. For Maven, no changes are necessary. If you are using Gradle, you will have to add the maven local repository such as this [example](https://github.com/deepjavalibrary/djl-demo/blob/135c969d66d98d1672852e53a37e52ca1da3e325/pneumonia-detection/build.gradle#L11):

Expand All @@ -101,3 +111,5 @@ repositories {
mavenLocal()
}
```

Note that `mavenCentral()` may still be needed for applications like log4j and json.
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,12 @@ public void set(Buffer data) {
JnaUtils.syncCopyFromCPU(getHandle(), buf, size);
}

/** {@inheritDoc} */
@Override
public NDArray gather(NDArray index, int axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public void copyTo(NDArray ndArray) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,15 @@ public NDArray get(long... indices) {
return JniUtils.getItem(this, indices);
}

/** {@inheritDoc} */
@Override
public NDArray gather(NDArray index, int axis) {
if (!(index instanceof PtNDArray)) {
throw new IllegalArgumentException("Only PtNDArray is supported.");
}
return JniUtils.gather(this, (PtNDArray) index, axis);
}

/** {@inheritDoc} */
@Override
public void copyTo(NDArray array) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,12 @@ public static void set(PtNDArray self, ByteBuffer data) {
PyTorchLibrary.LIB.torchSet(self.getHandle(), data);
}

public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) {
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false));
}

public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) {
Shape indexShape = index.getShape();
Shape ndShape = ndArray.getShape();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,12 @@ public void set(Buffer data) {
JavacppUtils.setByteBuffer(getHandle(), buf);
}

/** {@inheritDoc} */
@Override
public NDArray gather(NDArray index, int axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public void attach(NDManager manager) {
Expand Down
2 changes: 1 addition & 1 deletion gradle/wrapper/gradle-wrapper.properties
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.4.1-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-7.4.2-bin.zip
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import ai.djl.testing.TestRequirements;
import org.testng.Assert;
import org.testng.annotations.Test;

Expand Down Expand Up @@ -52,6 +53,22 @@ public void testPick() {
}
}

@Test
public void testGather() {
// Currently in windows gradle cannot find all the engines.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment still valid?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably. It seems like something is wrong with runtimeOnly dependencies on windows gradle.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes after updating gradle, on GitHub the test stills fails for the same reason

// In the dependencies, changing runtimeOnly to api however will remedy the problem.
// TODO: remove this when gradle problem is fixed.
TestRequirements.notWindows();
try (NDManager manager = NDManager.newBaseManager()) {
NDArray arr = manager.arange(20f).reshape(-1, 4);
long[] idx = {0, 0, 2, 1, 1, 2};
NDArray index = manager.create(idx, new Shape(3, 2));
NDArray actual = arr.gather(index, 1);
NDArray expected = manager.create(new float[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2));
Assert.assertEquals(actual, expected);
}
}

@Test
public void testGet() {
try (NDManager manager = NDManager.newBaseManager()) {
Expand Down