diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index cc9844f5ee1..e700c6375e4 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -4724,6 +4724,13 @@ default NDArray countNonzero(int axis) { */ NDArrayEx getNDArrayInternal(); + /** + * Returns {@code true} if this NDArray has been released. + * + * @return {@code true} if this NDArray has been released + */ + boolean isReleased(); + /** * Runs the debug string representation of this {@code NDArray}. * @@ -4755,6 +4762,9 @@ default String toDebugString(boolean withContent) { */ default String toDebugString( int maxSize, int maxDepth, int maxRows, int maxColumns, boolean withContent) { + if (isReleased()) { + return "This array is already closed"; + } return NDFormat.format(this, maxSize, maxDepth, maxRows, maxColumns, withContent); } diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index fac29be9ab1..fcad8dff88f 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -1177,6 +1177,12 @@ public NDArrayEx getNDArrayInternal() { return array.getNDArrayInternal(); } + /** {@inheritDoc} */ + @Override + public boolean isReleased() { + return isClosed; + } + /** {@inheritDoc} */ @Override public void close() { diff --git a/api/src/main/java/ai/djl/ndarray/NDScope.java b/api/src/main/java/ai/djl/ndarray/NDScope.java new file mode 100644 index 00000000000..b705d08c0bf --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/NDScope.java @@ -0,0 +1,82 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.IdentityHashMap; + +/** + * A class that tracks {@link NDResource} objects created in the try-with-resource block and close + * them automatically when out of the block scope. + * + *

This class has been derived from {@code org.bytedeco.javacpp.PointerScope} by Samuel Audet + */ +public class NDScope implements AutoCloseable { + + private static final ThreadLocal> SCOPE_STACK = + ThreadLocal.withInitial(ArrayDeque::new); + + private IdentityHashMap resources; + + /** Constructs a new {@code NDScope} instance. */ + public NDScope() { + resources = new IdentityHashMap<>(); + SCOPE_STACK.get().addLast(this); + } + + /** + * Registers {@link NDArray} object to this scope. + * + * @param array the {@link NDArray} object + */ + public static void register(NDArray array) { + Deque queue = SCOPE_STACK.get(); + if (queue.isEmpty()) { + return; + } + queue.getLast().resources.put(array, array); + } + + /** + * Unregisters {@link NDArray} object from this scope. + * + * @param array the {@link NDArray} object + */ + public static void unregister(NDArray array) { + Deque queue = SCOPE_STACK.get(); + if (queue.isEmpty()) { + return; + } + queue.getLast().resources.remove(array); + } + + /** {@inheritDoc} */ + @Override + public void close() { + for (NDArray array : resources.keySet()) { + array.close(); + } + SCOPE_STACK.get().remove(this); + } + + /** + * A method that does nothing. + * + *

You may use it if you do not have a better way to suppress the warning of a created but + * not explicitly used scope. + */ + public void suppressNotUsedWarning() { + // do nothing + } +} diff --git a/api/src/test/java/ai/djl/ndarray/NDScopeTest.java b/api/src/test/java/ai/djl/ndarray/NDScopeTest.java new file mode 100644 index 00000000000..eef151fec11 --- /dev/null +++ b/api/src/test/java/ai/djl/ndarray/NDScopeTest.java @@ -0,0 +1,48 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray; + +import org.testng.Assert; +import org.testng.annotations.Test; + +public class NDScopeTest { + + @Test + @SuppressWarnings("try") + public void testNDScope() { + NDArray detached; + NDArray inside; + NDArray uninvolved; + try (NDManager manager = NDManager.newBaseManager()) { + try (NDScope scope = new NDScope()) { + scope.suppressNotUsedWarning(); + try (NDScope ignore = new NDScope()) { + uninvolved = manager.create(new int[] {1}); + uninvolved.close(); + inside = manager.create(new int[] {1}); + // not tracked by any NDScope, but still managed by NDManager + NDScope.unregister(inside); + } + + detached = manager.create(new int[] {1}); + detached.detach(); // detached from NDManager + NDScope.unregister(detached); // and unregistered from NDScope + } + + Assert.assertFalse(inside.isReleased()); + } + Assert.assertTrue(inside.isReleased()); + Assert.assertFalse(detached.isReleased()); + detached.close(); + } +} diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index 10dbf7d14db..e62c593fc95 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.NDScope; import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; @@ -89,6 +90,7 @@ public class MxNDArray extends NativeResource implements LazyNDArray { this.manager = manager; mxNDArrayEx = new MxNDArrayEx(this); manager.attachInternal(getUid(), this); + NDScope.register(this); } /** diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 26b786f54ac..f77880da789 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.NDScope; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; @@ -64,6 +65,7 @@ public PtNDArray(PtNDManager manager, long handle) { this.manager = manager; this.ptNDArrayEx = new PtNDArrayEx(this); manager.attachInternal(getUid(), this); + NDScope.register(this); } /** @@ -80,6 +82,7 @@ public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) { this.ptNDArrayEx = new PtNDArrayEx(this); manager.attachInternal(getUid(), this); dataRef = data; + NDScope.register(this); } /** @@ -96,6 +99,7 @@ public PtNDArray(PtNDManager manager, String[] strs, Shape shape) { this.strs = strs; this.shape = shape; this.dataType = DataType.STRING; + NDScope.register(this); } /** {@inheritDoc} */ diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 920630b5236..aea57521137 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDArrays; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.NDScope; import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; @@ -55,6 +56,7 @@ public class TfNDArray extends NativeResource implements NDArr this.manager = manager; manager.attachInternal(getUid(), this); tfNDArrayEx = new TfNDArrayEx(this); + NDScope.register(this); } TfNDArray(TfNDManager manager, TFE_TensorHandle handle, TF_Tensor tensor) {