Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements NDScope to automatically close NDArray in the scope #2321

Merged
merged 10 commits into from
Jan 26, 2023
10 changes: 10 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -4724,6 +4724,13 @@ default NDArray countNonzero(int axis) {
*/
NDArrayEx getNDArrayInternal();

/**
* Returns {@code true} if this NDArray has been released.
*
* @return {@code true} if this NDArray has been released
*/
boolean isReleased();

/**
* Runs the debug string representation of this {@code NDArray}.
*
Expand Down Expand Up @@ -4755,6 +4762,9 @@ default String toDebugString(boolean withContent) {
*/
default String toDebugString(
int maxSize, int maxDepth, int maxRows, int maxColumns, boolean withContent) {
if (isReleased()) {
return "This array is already closed";
}
return NDFormat.format(this, maxSize, maxDepth, maxRows, maxColumns, withContent);
}

Expand Down
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,12 @@ public NDArrayEx getNDArrayInternal() {
return array.getNDArrayInternal();
}

/** {@inheritDoc} */
@Override
public boolean isReleased() {
return isClosed;
}

/** {@inheritDoc} */
@Override
public void close() {
Expand Down
82 changes: 82 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDScope.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.IdentityHashMap;

/**
* A class that tracks {@link NDResource} objects created in the try-with-resource block and close
* them automatically when out of the block scope.
*
* <p>This class has been derived from {@code org.bytedeco.javacpp.PointerScope} by Samuel Audet
*/
public class NDScope implements AutoCloseable {

private static final ThreadLocal<Deque<NDScope>> SCOPE_STACK =
ThreadLocal.withInitial(ArrayDeque::new);

private IdentityHashMap<NDArray, NDArray> resources;

/** Constructs a new {@code NDScope} instance. */
public NDScope() {
resources = new IdentityHashMap<>();
SCOPE_STACK.get().addLast(this);
}

/**
* Registers {@link NDArray} object to this scope.
*
* @param array the {@link NDArray} object
*/
public static void register(NDArray array) {
Deque<NDScope> queue = SCOPE_STACK.get();
if (queue.isEmpty()) {
return;
}
queue.getLast().resources.put(array, array);
}

/**
* Unregisters {@link NDArray} object from this scope.
*
* @param array the {@link NDArray} object
*/
public static void unregister(NDArray array) {
Deque<NDScope> queue = SCOPE_STACK.get();
if (queue.isEmpty()) {
return;
}
queue.getLast().resources.remove(array);
}

/** {@inheritDoc} */
@Override
public void close() {
for (NDArray array : resources.keySet()) {
array.close();
}
SCOPE_STACK.get().remove(this);
}

/**
* A method that does nothing.
*
* <p>You may use it if you do not have a better way to suppress the warning of a created but
* not explicitly used scope.
*/
public void suppressNotUsedWarning() {
// do nothing
}
}
48 changes: 48 additions & 0 deletions api/src/test/java/ai/djl/ndarray/NDScopeTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray;

import org.testng.Assert;
import org.testng.annotations.Test;

public class NDScopeTest {

@Test
@SuppressWarnings("try")
public void testNDScope() {
NDArray detached;
NDArray inside;
NDArray uninvolved;
try (NDManager manager = NDManager.newBaseManager()) {
try (NDScope scope = new NDScope()) {
scope.suppressNotUsedWarning();
try (NDScope ignore = new NDScope()) {
uninvolved = manager.create(new int[] {1});
uninvolved.close();
inside = manager.create(new int[] {1});
// not tracked by any NDScope, but still managed by NDManager
NDScope.unregister(inside);
}

detached = manager.create(new int[] {1});
detached.detach(); // detached from NDManager
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. In most cases, when NDArray.detach() is called, user want to return the NDArray from the function.
  2. And usually, NDManager is at the out side if NDScope

Which means user has to call both function in most of time. And we documented that once NDArray is detached, user must manually close it.

So I feel we should unregister from the NDScope when NDArray.detach() is called.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we are abusing NDArray.detach() in many internal place (attache(), they should only detach from NDManager only. Introducing two detach(boolean) function seems too much. Requires user to call both should be fine for now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is your intention to introduce NDScope?

Copy link
Contributor Author

@enpasos enpasos Jan 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My intention would be: Implement it because you need it to close a memory leak that you cannot solve otherwise. So I would not use NDScope if I did not have to. But if I have a leak - which I unfortunately do - I would put the NDScope where I think no NDArray should leak through - no matter if inside or outside of NDManager and no matter how they are used. And in that case I would not want NDArrays to leak through for whatever reason.

Copy link
Contributor Author

@enpasos enpasos Jan 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we are abusing NDArray.detach() in many internal place (attache(), they should only detach from NDManager only. Introducing two detach(boolean) function seems too much. Requires user to call both should be fine for now.

I agree that if there was a convenient method where the user knew exactly that he/she was also unregistering from NDScope, there should be no problem.

Copy link
Contributor Author

@enpasos enpasos Jan 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we leave it as it is for this PR?

NDScope.unregister(detached); // and unregistered from NDScope
}

Assert.assertFalse(inside.isReleased());
}
Assert.assertTrue(inside.isReleased());
Assert.assertFalse(detached.isReleased());
detached.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
Expand Down Expand Up @@ -89,6 +90,7 @@ public class MxNDArray extends NativeResource<Pointer> implements LazyNDArray {
this.manager = manager;
mxNDArrayEx = new MxNDArrayEx(this);
manager.attachInternal(getUid(), this);
NDScope.register(this);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
Expand Down Expand Up @@ -64,6 +65,7 @@ public PtNDArray(PtNDManager manager, long handle) {
this.manager = manager;
this.ptNDArrayEx = new PtNDArrayEx(this);
manager.attachInternal(getUid(), this);
NDScope.register(this);
}

/**
Expand All @@ -80,6 +82,7 @@ public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) {
this.ptNDArrayEx = new PtNDArrayEx(this);
manager.attachInternal(getUid(), this);
dataRef = data;
NDScope.register(this);
}

/**
Expand All @@ -96,6 +99,7 @@ public PtNDArray(PtNDManager manager, String[] strs, Shape shape) {
this.strs = strs;
this.shape = shape;
this.dataType = DataType.STRING;
NDScope.register(this);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
Expand Down Expand Up @@ -55,6 +56,7 @@ public class TfNDArray extends NativeResource<TFE_TensorHandle> implements NDArr
this.manager = manager;
manager.attachInternal(getUid(), this);
tfNDArrayEx = new TfNDArrayEx(this);
NDScope.register(this);
}

TfNDArray(TfNDManager manager, TFE_TensorHandle handle, TF_Tensor tensor) {
Expand Down