diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index cc9844f5ee1..e700c6375e4 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -4724,6 +4724,13 @@ default NDArray countNonzero(int axis) { */ NDArrayEx getNDArrayInternal(); + /** + * Returns {@code true} if this NDArray has been released. + * + * @return {@code true} if this NDArray has been released + */ + boolean isReleased(); + /** * Runs the debug string representation of this {@code NDArray}. * @@ -4755,6 +4762,9 @@ default String toDebugString(boolean withContent) { */ default String toDebugString( int maxSize, int maxDepth, int maxRows, int maxColumns, boolean withContent) { + if (isReleased()) { + return "This array is already closed"; + } return NDFormat.format(this, maxSize, maxDepth, maxRows, maxColumns, withContent); } diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index fac29be9ab1..fcad8dff88f 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -1177,6 +1177,12 @@ public NDArrayEx getNDArrayInternal() { return array.getNDArrayInternal(); } + /** {@inheritDoc} */ + @Override + public boolean isReleased() { + return isClosed; + } + /** {@inheritDoc} */ @Override public void close() { diff --git a/api/src/main/java/ai/djl/ndarray/NDScope.java b/api/src/main/java/ai/djl/ndarray/NDScope.java new file mode 100644 index 00000000000..b705d08c0bf --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/NDScope.java @@ -0,0 +1,82 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.IdentityHashMap; + +/** + * A class that tracks {@link NDResource} objects created in the try-with-resource block and close + * them automatically when out of the block scope. + * + *
This class has been derived from {@code org.bytedeco.javacpp.PointerScope} by Samuel Audet
+ */
+public class NDScope implements AutoCloseable {
+
+ private static final ThreadLocal You may use it if you do not have a better way to suppress the warning of a created but
+ * not explicitly used scope.
+ */
+ public void suppressNotUsedWarning() {
+ // do nothing
+ }
+}
diff --git a/api/src/test/java/ai/djl/ndarray/NDScopeTest.java b/api/src/test/java/ai/djl/ndarray/NDScopeTest.java
new file mode 100644
index 00000000000..eef151fec11
--- /dev/null
+++ b/api/src/test/java/ai/djl/ndarray/NDScopeTest.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.ndarray;
+
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+public class NDScopeTest {
+
+ @Test
+ @SuppressWarnings("try")
+ public void testNDScope() {
+ NDArray detached;
+ NDArray inside;
+ NDArray uninvolved;
+ try (NDManager manager = NDManager.newBaseManager()) {
+ try (NDScope scope = new NDScope()) {
+ scope.suppressNotUsedWarning();
+ try (NDScope ignore = new NDScope()) {
+ uninvolved = manager.create(new int[] {1});
+ uninvolved.close();
+ inside = manager.create(new int[] {1});
+ // not tracked by any NDScope, but still managed by NDManager
+ NDScope.unregister(inside);
+ }
+
+ detached = manager.create(new int[] {1});
+ detached.detach(); // detached from NDManager
+ NDScope.unregister(detached); // and unregistered from NDScope
+ }
+
+ Assert.assertFalse(inside.isReleased());
+ }
+ Assert.assertTrue(inside.isReleased());
+ Assert.assertFalse(detached.isReleased());
+ detached.close();
+ }
+}
diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java
index 10dbf7d14db..e62c593fc95 100644
--- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java
+++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java
@@ -19,6 +19,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
@@ -89,6 +90,7 @@ public class MxNDArray extends NativeResource