From 5677f11ef86a8d63d959f8216db113070ea9b3a2 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 25 Jan 2023 20:20:31 -0800 Subject: [PATCH 01/10] [api] Fixes toDebugString() IllegalStateException if NDArray is closed --- api/src/main/java/ai/djl/ndarray/NDArray.java | 10 ++++++++++ api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index cc9844f5ee1..e700c6375e4 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -4724,6 +4724,13 @@ default NDArray countNonzero(int axis) { */ NDArrayEx getNDArrayInternal(); + /** + * Returns {@code true} if this NDArray has been released. + * + * @return {@code true} if this NDArray has been released + */ + boolean isReleased(); + /** * Runs the debug string representation of this {@code NDArray}. * @@ -4755,6 +4762,9 @@ default String toDebugString(boolean withContent) { */ default String toDebugString( int maxSize, int maxDepth, int maxRows, int maxColumns, boolean withContent) { + if (isReleased()) { + return "This array is already closed"; + } return NDFormat.format(this, maxSize, maxDepth, maxRows, maxColumns, withContent); } diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index fac29be9ab1..fcad8dff88f 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -1177,6 +1177,12 @@ public NDArrayEx getNDArrayInternal() { return array.getNDArrayInternal(); } + /** {@inheritDoc} */ + @Override + public boolean isReleased() { + return isClosed; + } + /** {@inheritDoc} */ @Override public void close() { From 4f6ab894055d8ddf1338044af1468cf70eb589c8 Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 12 Jan 2023 19:08:27 +0100 Subject: [PATCH 02/10] Implements NDScope based on 
JavaCPP PointerScope --- .../java/ai/djl/ndarray/BaseNDManager.java | 150 +++++++----- .../ai/djl/ndarray/refcount/Deallocator.java | 23 ++ .../refcount/DeallocatorReference.java | 131 ++++++++++ .../ai/djl/ndarray/refcount/RCConfig.java | 39 +++ .../ai/djl/ndarray/refcount/RCException.java | 47 ++++ .../ai/djl/ndarray/refcount/RCObject.java | 225 ++++++++++++++++++ .../java/ai/djl/ndarray/refcount/RCScope.java | 208 ++++++++++++++++ .../ndarray/refcount/ReferenceCounter.java | 39 +++ .../ai/djl/ndarray/refcount/package-info.java | 18 ++ .../main/java/ai/djl/util/NativeResource.java | 92 ++++++- .../java/ai/djl/pytorch/engine/PtNDArray.java | 21 +- .../pytorch/engine/PtNDArrayDeallocator.java | 49 ++++ .../ai/djl/pytorch/refcount/RCObjectTest.java | 55 +++++ .../ai/djl/pytorch/refcount/RCScopeTest.java | 105 ++++++++ .../ai/djl/pytorch/refcount/package-info.java | 18 ++ gradle.properties | 2 +- tools/conf/findbugs-exclude.xml | 12 + 17 files changed, 1164 insertions(+), 70 deletions(-) create mode 100644 api/src/main/java/ai/djl/ndarray/refcount/Deallocator.java create mode 100644 api/src/main/java/ai/djl/ndarray/refcount/DeallocatorReference.java create mode 100644 api/src/main/java/ai/djl/ndarray/refcount/RCConfig.java create mode 100644 api/src/main/java/ai/djl/ndarray/refcount/RCException.java create mode 100644 api/src/main/java/ai/djl/ndarray/refcount/RCObject.java create mode 100644 api/src/main/java/ai/djl/ndarray/refcount/RCScope.java create mode 100644 api/src/main/java/ai/djl/ndarray/refcount/ReferenceCounter.java create mode 100644 api/src/main/java/ai/djl/ndarray/refcount/package-info.java create mode 100644 engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayDeallocator.java create mode 100644 engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCObjectTest.java create mode 100644 engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCScopeTest.java create mode 100644 
engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/package-info.java diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index abbc3c31605..8addec50d04 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -63,6 +63,77 @@ protected BaseNDManager(NDManager parent, Device device) { } } + /** + * Checks if the input buffer size is match expected data type. + * + * @param buffer the input buffer + * @param dataType the desired {@code DataType} + * @param expected the expected size + * @throws IllegalArgumentException if buffer size is invalid + */ + public static void validateBuffer(Buffer buffer, DataType dataType, int expected) { + boolean isByteBuffer = buffer instanceof ByteBuffer; + DataType type = DataType.fromBuffer(buffer); + if (type != dataType && !isByteBuffer) { + // It's ok if type != datatype and buffer is ByteBuffer, + // since buffer will be copied into ByteBuffer + throw new IllegalArgumentException( + "The input data type: " + + type + + " does not match target array data type: " + + dataType); + } + + int remaining = buffer.remaining(); + int expectedSize = isByteBuffer ? dataType.getNumOfBytes() * expected : expected; + if (remaining < expectedSize) { + throw new IllegalArgumentException( + "The NDArray size is: " + expected + ", but buffer size is: " + remaining); + } + if (remaining > expectedSize) { + logger.warn( + "Input buffer size is greater than the NDArray size, please set limit" + + " explicitly."); + buffer.limit(expectedSize); + } + } + + /** + * Copies data from the source {@code Buffer} to the target {@code ByteBuffer}. 
+ * + * @param src the source {@code Buffer} + * @param target the target {@code ByteBuffer} + */ + public static void copyBuffer(Buffer src, ByteBuffer target) { + target.rewind(); + DataType inputType = DataType.fromBuffer(src); + switch (inputType) { + case FLOAT16: + target.asShortBuffer().put((ShortBuffer) src); + break; + case FLOAT32: + target.asFloatBuffer().put((FloatBuffer) src); + break; + case FLOAT64: + target.asDoubleBuffer().put((DoubleBuffer) src); + break; + case UINT8: + case INT8: + case BOOLEAN: + target.put((ByteBuffer) src); + break; + case INT32: + target.asIntBuffer().put((IntBuffer) src); + break; + case INT64: + target.asLongBuffer().put((LongBuffer) src); + break; + default: + throw new AssertionError("Unsupported datatype: " + inputType); + } + target.rewind(); + } + /** {@inheritDoc} */ @Override public final Device defaultDevice() { @@ -107,14 +178,14 @@ public NDList load(Path path) { /** {@inheritDoc} */ @Override - public void setName(String name) { - this.name = name; + public String getName() { + return this.name == null ? uid : this.name; } /** {@inheritDoc} */ @Override - public String getName() { - return this.name == null ? uid : this.name; + public void setName(String name) { + this.name = name; } /** {@inheritDoc} */ @@ -468,74 +539,23 @@ NDManager getAlternativeManager() { } /** - * Checks if the input buffer size is match expected data type. + * Returns true if the resource is a resource of this manager. 
* - * @param buffer the input buffer - * @param dataType the desired {@code DataType} - * @param expected the expected size - * @throws IllegalArgumentException if buffer size is invalid + * @param resource the resource to check + * @return true if the resource is a resource of this manager */ - public static void validateBuffer(Buffer buffer, DataType dataType, int expected) { - boolean isByteBuffer = buffer instanceof ByteBuffer; - DataType type = DataType.fromBuffer(buffer); - if (type != dataType && !isByteBuffer) { - // It's ok if type != datatype and buffer is ByteBuffer, - // since buffer will be copied into ByteBuffer - throw new IllegalArgumentException( - "The input data type: " - + type - + " does not match target array data type: " - + dataType); - } - - int remaining = buffer.remaining(); - int expectedSize = isByteBuffer ? dataType.getNumOfBytes() * expected : expected; - if (remaining < expectedSize) { - throw new IllegalArgumentException( - "The NDArray size is: " + expected + ", but buffer size is: " + remaining); - } - if (remaining > expectedSize) { - logger.warn( - "Input buffer size is greater than the NDArray size, please set limit" - + " explicitly."); - buffer.limit(expectedSize); - } + public boolean hasResource(AutoCloseable resource) { + return this.resources.values().contains(resource); } /** - * Copies data from the source {@code Buffer} to the target {@code ByteBuffer}. + * Returns true if the resource is a temporary resource. 
* - * @param src the source {@code Buffer} - * @param target the target {@code ByteBuffer} + * @param resource the resource to check + * @return true if the resource is a temporary resource */ - public static void copyBuffer(Buffer src, ByteBuffer target) { - target.rewind(); - DataType inputType = DataType.fromBuffer(src); - switch (inputType) { - case FLOAT16: - target.asShortBuffer().put((ShortBuffer) src); - break; - case FLOAT32: - target.asFloatBuffer().put((FloatBuffer) src); - break; - case FLOAT64: - target.asDoubleBuffer().put((DoubleBuffer) src); - break; - case UINT8: - case INT8: - case BOOLEAN: - target.put((ByteBuffer) src); - break; - case INT32: - target.asIntBuffer().put((IntBuffer) src); - break; - case INT64: - target.asLongBuffer().put((LongBuffer) src); - break; - default: - throw new AssertionError("Unsupported datatype: " + inputType); - } - target.rewind(); + public boolean hasTempResource(AutoCloseable resource) { + return this.resources.values().contains(resource); } protected static final class TempResource { diff --git a/api/src/main/java/ai/djl/ndarray/refcount/Deallocator.java b/api/src/main/java/ai/djl/ndarray/refcount/Deallocator.java new file mode 100644 index 00000000000..1f180a687b2 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/refcount/Deallocator.java @@ -0,0 +1,23 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.ndarray.refcount; + +/** + * The interface to implement to produce a Deallocator usable by referenceCountedObject. + * + * <p>
This interface has been derived from {@code org.bytedeco.javacpp.Pointer} by Samuel Audet + */ +public interface Deallocator { + /** The method to implement to produce a Deallocator usable by referenceCountedObject. */ + void deallocate(); +} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/DeallocatorReference.java b/api/src/main/java/ai/djl/ndarray/refcount/DeallocatorReference.java new file mode 100644 index 00000000000..1b056ededb6 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/refcount/DeallocatorReference.java @@ -0,0 +1,131 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.refcount; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.ref.PhantomReference; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * A subclass of {@link PhantomReference} that also acts as a linked list to keep their references + * alive until they get garbage collected. Implements reference counting with an {@link + * AtomicInteger} count. + * + *
<p>
This class has been derived from {@code org.bytedeco.javacpp.Pointer} by Samuel Audet + */ +@SuppressWarnings({"PMD.MutableStaticState", "PMD.AvoidUsingVolatile"}) +public class DeallocatorReference extends PhantomReference + implements Deallocator, ReferenceCounter { + + private static final Logger logger = LoggerFactory.getLogger(DeallocatorReference.class); + static volatile DeallocatorReference head; + static volatile long totalCount; + volatile DeallocatorReference prev = this; + volatile DeallocatorReference next = this; + protected Deallocator deallocator; + AtomicInteger count; + + protected DeallocatorReference(RCObject p, Deallocator deallocator) { + super(p, null); + this.deallocator = deallocator; + this.count = new AtomicInteger(0); + } + + final void add() { + synchronized (DeallocatorReference.class) { + if (head == null) { + head = this; + prev = next = null; + } else { + prev = null; + next = head; + next.prev = head = this; + } + totalCount++; + } + } + + final void remove() { + if (prev == this && next == this) { + return; + } + synchronized (DeallocatorReference.class) { + if (prev == null) { + head = next; + } else { + prev.next = next; + } + if (next != null) { + next.prev = prev; + } + prev = next = this; + totalCount--; + } + } + + /** {@inheritDoc} */ + @Override + public void clear() { + super.clear(); + if (deallocator != null) { + if (logger.isDebugEnabled()) { + logger.trace("Collecting " + this); + } + deallocate(); + } + } + + /** {@inheritDoc} */ + @Override + public void deallocate() { + if (deallocator != null) { + deallocator.deallocate(); + deallocator = null; + } + } + + /** {@inheritDoc} */ + @Override + public void retain() { + if (deallocator != null) { + count.incrementAndGet(); + } + } + + /** {@inheritDoc} */ + @Override + public boolean release() { + if (deallocator != null && count.decrementAndGet() <= 0) { + if (logger.isDebugEnabled()) { + logger.trace("Releasing " + this); + } + deallocate(); + return true; + } + 
return false; + } + + /** {@inheritDoc} */ + @Override + public int count() { + return deallocator != null ? count.get() : -1; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return getClass().getName() + "[deallocator=" + deallocator + ",count=" + count + "]"; + } +} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/RCConfig.java b/api/src/main/java/ai/djl/ndarray/refcount/RCConfig.java new file mode 100644 index 00000000000..b10349a19c1 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/refcount/RCConfig.java @@ -0,0 +1,39 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.refcount; + +/** RCConfig holds configurable parameters. */ +public final class RCConfig { + + private static boolean verboseIfResourceAlreadyClosed; + + private RCConfig() {} + + /** + * If true, a verbose message is printed if a resource is already closed. + * + * @return true if the verboseIfResourceAlreadyClosed is set + */ + public static boolean isVerboseIfResourceAlreadyClosed() { + return verboseIfResourceAlreadyClosed; + } + + /** + * If true, a verbose message is printed if a resource is already closed. 
+ * + * @param verboseIfResourceAlreadyClosed parameter to set + */ + public static void setVerboseIfResourceAlreadyClosed(boolean verboseIfResourceAlreadyClosed) { + RCConfig.verboseIfResourceAlreadyClosed = verboseIfResourceAlreadyClosed; + } +} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/RCException.java b/api/src/main/java/ai/djl/ndarray/refcount/RCException.java new file mode 100644 index 00000000000..bbdc494eddb --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/refcount/RCException.java @@ -0,0 +1,47 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.refcount; + +/** A runtime exception thrown within the reference counting package. */ +public class RCException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + /** + * Constructs a new runtime exception with the specified detail message. + * + * @param message - the message to be displayed + */ + public RCException(String message) { + super(message); + } + + /** + * Constructs a new runtime exception with the specified detail message and cause. + * + * @param message - the message to be displayed + * @param cause - the cause of the exception + */ + public RCException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Constructs a new runtime exception with the specified cause. 
+ * + * @param cause - the cause of the exception + */ + public RCException(Throwable cause) { + super(cause); + } +} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/RCObject.java b/api/src/main/java/ai/djl/ndarray/refcount/RCObject.java new file mode 100644 index 00000000000..248ba6e04b0 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/refcount/RCObject.java @@ -0,0 +1,225 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.refcount; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Iterator; + +/** + * All peer classes to native types must be descended from {@link RCObject} (reference counted + * object), the topmost class. + * + *
<p>It is also possible to use a {@link RCScope} to keep track of a group of {@link RCObject} + * objects, and have them deallocated in a transparent but deterministic manner. + * + * <p>
This class has been derived from {@code org.bytedeco.javacpp.Pointer} by Samuel Audet + */ +@SuppressWarnings("PMD.AvoidBranchingStatementAsLastInLoop") +public class RCObject { + + private static final Logger logger = LoggerFactory.getLogger(RCObject.class); + + private Deallocator deallocator; + + /** + * Returns {@link DeallocatorReference#totalCount}, current number of ReferenceCountedObjects + * tracked by deallocators. + * + * @return the total count + */ + public static long totalCount() { + return DeallocatorReference.totalCount; + } + + /** + * Returns {@link ReferenceCounter#count()} or -1 if no deallocator has been set. + * + * @return the count + */ + public int referenceCount() { + ReferenceCounter r = (ReferenceCounter) deallocator; + return r != null ? r.count() : -1; + } + + /** + * Returns {@link #deallocator}. + * + * @return the deallocator + */ + protected Deallocator deallocator() { + return deallocator; + } + + /** + * Sets the deallocator and returns this. Also clears current deallocator if not {@code null}. + * That is, it deallocates previously allocated memory. Should not be called more than once + * after allocation. + * + * @param deallocator the new deallocator + * @param
<P> the type of the referenceCountedObject + * @return this referenceCountedObject + */ + protected <P extends RCObject>
P deallocator(Deallocator deallocator) { + if (this.deallocator != null) { + if (logger.isDebugEnabled()) { + logger.debug("Predeallocating " + this); + } + this.deallocator.deallocate(); + this.deallocator = null; + } + if (deallocator != null) { + DeallocatorReference r = + deallocator instanceof DeallocatorReference + ? (DeallocatorReference) deallocator + : new DeallocatorReference(this, deallocator); + this.deallocator = r; + Iterator<RCScope> it = RCScope.getScopeIterator(); + if (it != null) { + while (it.hasNext()) { + try { + it.next().attach(this); + } catch (IllegalArgumentException e) { + // try the next scope down the stack + continue; + } + break; + } + } + } + + @SuppressWarnings("unchecked") + P p = (P) this; + return p; + } + + /** Calls {@code deallocate(true)}. */ + public void deallocate() { + deallocate(true); + } + + /** + * Explicitly manages native memory without waiting after the garbage collector. Has no effect + * if no deallocator was previously set with {@link #deallocator(Deallocator)}. + * + * @param deallocate if true, deallocates, else does not, but disables garbage collection + */ + public void deallocate(boolean deallocate) { + DeallocatorReference r = (DeallocatorReference) deallocator; + if (deallocate && deallocator != null) { + if (logger.isDebugEnabled()) { + logger.debug("Deallocating " + this); + } + deallocator.deallocate(); + deallocator = null; + // address = 0; + } + if (r != null) { + // remove from queue without calling the deallocator + r.deallocator = null; + r.clear(); + r.remove(); + r.deallocator = deallocator; + } + } + + /** + * Calls {@link ReferenceCounter#retain()}, incrementing the reference count by 1. Has no effect + * if no deallocator was previously set with {@link #deallocator(Deallocator)}. + * + * @param
<P> the type of the referenceCountedObject + * @return this + */ + public <P extends RCObject>
P retainReference() { + ReferenceCounter r = (ReferenceCounter) deallocator; + if (r != null) { + r.retain(); + } + @SuppressWarnings("unchecked") + P p = (P) this; + return p; + } + + /** + * Calls {@link ReferenceCounter#release()}, decrementing the reference count by 1, in turn + * deallocating this referenceCountedObject when the count drops to 0. Has no effect if no + * deallocator was previously set with {@link #deallocator(Deallocator)}. + * + * @return true when the count drops to 0 and deallocation has occurred + */ + public boolean releaseReference() { + DeallocatorReference r = (DeallocatorReference) deallocator; + if (r != null && r.release()) { + deallocator = null; + // address = 0; + r.clear(); + r.remove(); + return true; + } + return false; + } + + /** + * Calls in effect {@code memcpy(this.address + this.position, p.address + p.position, length)}, + * where {@code length = sizeof(p) * (p.limit - p.position)}. If limit == 0, it uses position + + * 1 instead. The way the methods were designed allows constructs such as {@code + * this.position(0).put(p.position(13).limit(42))}. + * + * @param p the referenceCountedObject from which to copy memory + * @param
<P> the type of the referenceCountedObject + * @return this + */ + public <P extends RCObject>
P put(RCObject p) { + @SuppressWarnings("unchecked") + P p2 = (P) this; + return p2; + } + + /** + * Calls in effect {@code memset(address + position, b, length)}, where {@code length = sizeof() + * * (limit - position)}. If limit == 0, it uses position + 1 instead. The way the methods were + * designed allows constructs such as {@code this.position(0).limit(13).fill(42)}; + * + * @param b the byte value to fill the memory with + * @param
<P> the type of the referenceCountedObject + * @return this + */ + public <P extends RCObject>
P fill(int b) { + @SuppressWarnings("unchecked") + P p = (P) this; + return p; + } + + /** + * Returns {@code fill(0)}. + * + * @param
<P> the type of the referenceCountedObject + * @return this + */ + public <P extends RCObject>
P zero() { + // repair warning: [unchecked] unchecked cast + @SuppressWarnings("unchecked") + P p = (P) this.fill(0); + return p; + } + + /** + * Returns whether the resource is null. + * + * @return whether the resource is null + */ + public boolean isNull() { + throw new UnsupportedOperationException("Not implemented."); + } +} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/RCScope.java b/api/src/main/java/ai/djl/ndarray/refcount/RCScope.java new file mode 100644 index 00000000000..11389f5bbec --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/refcount/RCScope.java @@ -0,0 +1,208 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.refcount; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Deque; +import java.util.Iterator; + +/** + * {@link RCObject} objects attach themselves automatically on creation to the first {@link RCScope} + * found in {@link #SCOPE_STACK} that they can to based on the classes found in {@link #forClasses}. + * The user can then call {@link #deallocate()}, or rely on {@link #close()} to release in a timely + * fashion all attached referenceCountedObject objects, instead of relying on the garbage collector. + * + *
<p>
This class has been derived from {@code org.bytedeco.javacpp.PointerScope} by Samuel Audet + */ +public class RCScope implements AutoCloseable { + /** + * A thread-local stack of {@link RCScope} objects. referenceCountedObject objects attach + * themselves automatically to the first one they can on the stack. + */ + static final ThreadLocal<Deque<RCScope>> SCOPE_STACK = + new ThreadLocal<Deque<RCScope>>() { + @Override + protected Deque<RCScope> initialValue() { + return new ArrayDeque<RCScope>(); + } + }; + + private static final Logger logger = LoggerFactory.getLogger(RCScope.class); + /** The stack keeping references to attached {@link RCObject} objects. */ + Deque<RCObject> referenceCountedObjectStack = new ArrayDeque<>(); + /** When not empty, indicates the classes of objects that are allowed to be attached. */ + Class<? extends RCObject>[] forClasses; + /** + * When set to true, the next call to {@link #close()} does not release but resets this + * variable. + */ + boolean extend; + + /** + * Creates a new scope accepting all referenceCountedObject types and pushes itself on the + * {@link #SCOPE_STACK}. + */ + public RCScope() { + this((Class<? extends RCObject>[]) null); + } + + /** + * Initializes {@link #forClasses}, and adds itself as first (push) on the {@link #SCOPE_STACK}. + * + * @param forClasses the classes of objects that are allowed to be attached + */ + @SafeVarargs + @SuppressWarnings("varargs") + public RCScope(Class<? extends RCObject>... forClasses) { + if (logger.isDebugEnabled()) { + logger.debug("Opening " + this); + } + this.forClasses = forClasses; + SCOPE_STACK.get().addFirst(this); + } + + /** + * Returns {@code SCOPE_STACK.get().peekFirst()} (peek), the last opened scope not yet closed. + * + * @return the last opened scope not yet closed + */ + public static RCScope getInnerScope() { + return SCOPE_STACK.get().peekFirst(); + } + + /** + * Returns {@code SCOPE_STACK.get().iterator()}, all scopes not yet closed.
+ * + * @return all scopes not yet closed + */ + public static Iterator<RCScope> getScopeIterator() { + return SCOPE_STACK.get().iterator(); + } + + /** + * When not empty, returns the classes of objects that are allowed to be attached. + * + * @return the classes of objects that are allowed to be attached + */ + public Class<? extends RCObject>[] forClasses() { + return forClasses; + } + + /** + * Pushes the referenceCountedObject onto the {@link #referenceCountedObjectStack} of this Scope + * and calls {@link RCObject#retainReference()}. + * + * @param p the referenceCountedObject to attach + * @return this scope + * @throws IllegalArgumentException when it is not an instance of a class in {@link + * #forClasses}. + */ + public RCScope attach(RCObject p) { + if (logger.isDebugEnabled()) { + logger.debug("Attaching " + p + " to " + this); + } + if (forClasses != null && forClasses.length > 0) { + boolean found = false; + for (Class<? extends RCObject> c : forClasses) { + if (c != null && c.isInstance(p)) { + found = true; + break; + } + } + if (!found) { + throw new IllegalArgumentException( + p + + " is not an instance of a class in forClasses: " + + Arrays.toString(forClasses)); + } + } + referenceCountedObjectStack.push(p); + p.retainReference(); + return this; + } + + /** + * Removes the referenceCountedObject from the {@link #referenceCountedObjectStack} of this + * Scope and calls {@link RCObject#releaseReference()}. + * + * @param p the referenceCountedObject to detach + * @return this scope + */ + public RCScope detach(RCObject p) { + if (logger.isDebugEnabled()) { + logger.debug("Detaching " + p + " from " + this); + } + referenceCountedObjectStack.remove(p); + p.releaseReference(); + return this; + } + + /** + * Extends the life of this scope past the next call to {@link #close()} by setting the {@link + * #extend} flag.
+ * + * @return this scope + */ + public RCScope extend() { + if (logger.isDebugEnabled()) { + logger.debug("Extending " + this); + } + extend = true; + return this; + } + + /** + * Pops from {@link #referenceCountedObjectStack} all attached ReferenceCountedObjects, calls + * {@link RCObject#releaseReference()} on them, unless extended, in which case it only resets + * the {@link #extend} flag instead, and finally removes itself from {@link #SCOPE_STACK}. + */ + @Override + public void close() { + if (logger.isDebugEnabled()) { + logger.debug("Closing " + this); + } + if (extend) { + extend = false; + } else { + while (referenceCountedObjectStack.size() > 0) { + referenceCountedObjectStack.pop().releaseReference(); + } + } + SCOPE_STACK.get().remove(this); + } + + /** + * Pops from {@link #referenceCountedObjectStack} all attached ReferenceCountedObjects, and + * calls {@link RCObject#deallocate()} on them. + */ + public void deallocate() { + if (logger.isDebugEnabled()) { + logger.debug("Deallocating " + this); + } + while (referenceCountedObjectStack.size() > 0) { + referenceCountedObjectStack.pop().deallocate(); + } + } + + /** + * A method that does nothing. You may use it if you do not have a better way to suppress the + * warning of a created but not explicitly used scope. + */ + public void suppressNotUsedWarning() { + // do nothing + } +} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/ReferenceCounter.java b/api/src/main/java/ai/djl/ndarray/refcount/ReferenceCounter.java new file mode 100644 index 00000000000..25ef134fba1 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/refcount/ReferenceCounter.java @@ -0,0 +1,39 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. 
A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.refcount; + +/** + * The ReferenceCounter interface. + * + *

This interface has been derived from {@code org.bytedeco.javacpp.Pointer} by Samuel Audet + */ +public interface ReferenceCounter { + + /** Increments the reference count by 1 starting from initially 0. */ + void retain(); + + /** + * Decrements the reference count by 1, in turn deallocating this Pointer when the count drops + * to 0. + * + * @return true when the count drops to 0 and deallocation has occurred + */ + boolean release(); + + /** + * Returns the count value. + * + * @return the count value + */ + int count(); +} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/package-info.java b/api/src/main/java/ai/djl/ndarray/refcount/package-info.java new file mode 100644 index 00000000000..cbee2abc6d1 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/refcount/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** + * Contains a reference counting implementation derived from JavaCPP's Pointer and PointerScope + * helping to avoid memory leaks in {@link ai.djl.ndarray.NDArray}. 
+ */ +package ai.djl.ndarray.refcount; diff --git a/api/src/main/java/ai/djl/util/NativeResource.java b/api/src/main/java/ai/djl/util/NativeResource.java index 65aa8f0085a..c0a12a0dd46 100644 --- a/api/src/main/java/ai/djl/util/NativeResource.java +++ b/api/src/main/java/ai/djl/util/NativeResource.java @@ -12,8 +12,15 @@ */ package ai.djl.util; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.refcount.RCConfig; +import ai.djl.ndarray.refcount.RCObject; + import com.sun.jna.Pointer; +import java.text.MessageFormat; +import java.time.Instant; +import java.util.Arrays; import java.util.concurrent.atomic.AtomicReference; /** @@ -22,14 +29,74 @@ * * @param the resource that could map to a native pointer or java object */ -public abstract class NativeResource implements AutoCloseable { +@SuppressWarnings("PMD.ConstructorCallsOverridableMethod") +public abstract class NativeResource extends RCObject implements AutoCloseable { protected final AtomicReference handle; - private String uid; + protected Instant creationTime; + + protected String uid; + private String creationStackTraceAsString; + private String closingStackTraceAsString; + + /** Constructs a new {@code NativeResource}. 
*/ + public NativeResource() { + // super(); + handle = new AtomicReference<>(); + this.creationTime = Instant.now(); + if (RCConfig.isVerboseIfResourceAlreadyClosed()) { + creationStackTraceAsString = stackTraceAsString(); + } + } protected NativeResource(T handle) { this.handle = new AtomicReference<>(handle); uid = handle.toString(); + this.creationTime = Instant.now(); + if (RCConfig.isVerboseIfResourceAlreadyClosed()) { + creationStackTraceAsString = stackTraceAsString(); + } + } + + private String fingerPrintOfNativeResourceWithStackTraceFromCreation() { + String name = "NO_NAME"; + if (this instanceof NDArray) { + name = ((NDArray) this).getName(); + } + return MessageFormat.format( + "NDArray named \"{0}\" identified by (uid:{1};createdAt:{2}) \n" + + "call stack at creation...{3}\n" + + "######### \n" + + "call stack at closing...{4}\n" + + "#########", + name, + getUid(), + creationTime, + creationStackTraceAsString, + closingStackTraceAsString); + } + + /** + * Returns the current stack trace as a string. + * + * @return the current stack trace as a string + */ + public static String stackTraceAsString() { + StringBuilder buf = new StringBuilder(); + Arrays.stream(Thread.currentThread().getStackTrace()) + .forEach( + s -> + buf.append( + "\nat " + + s.getClassName() + + "." + + s.getMethodName() + + "(" + + s.getFileName() + + ":" + + s.getLineNumber() + + ")")); + return buf.toString(); } /** @@ -49,7 +116,11 @@ public boolean isReleased() { public T getHandle() { T reference = handle.get(); if (reference == null) { - throw new IllegalStateException("Native resource has been release already."); + String message = "Native resource has been released already. 
"; + if (RCConfig.isVerboseIfResourceAlreadyClosed()) { + message += fingerPrintOfNativeResourceWithStackTraceFromCreation(); + } + throw new IllegalStateException(message); } return reference; } @@ -63,6 +134,21 @@ public final String getUid() { return uid; } + /** + * Sets the closingStackTraceAsString. + * + * @param closingStackTraceAsString the closingStackTraceAsString to set + */ + protected void setClosingStackTraceAsString(String closingStackTraceAsString) { + this.closingStackTraceAsString = closingStackTraceAsString; + } + + /** {@inheritDoc} */ + @Override + public boolean isNull() { + return handle.get() == null; + } + /** {@inheritDoc} */ @Override public void close() { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 26b786f54ac..d06e183a48f 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -17,6 +17,8 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.refcount.RCConfig; +import ai.djl.ndarray.refcount.RCObject; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; @@ -64,6 +66,7 @@ public PtNDArray(PtNDManager manager, long handle) { this.manager = manager; this.ptNDArrayEx = new PtNDArrayEx(this); manager.attachInternal(getUid(), this); + deallocator(new PtNDArrayDeallocator(this)); } /** @@ -76,6 +79,7 @@ public PtNDArray(PtNDManager manager, long handle) { */ public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) { super(handle); + deallocator(new PtNDArrayDeallocator(this)); this.manager = manager; this.ptNDArrayEx = new PtNDArrayEx(this); manager.attachInternal(getUid(), this); @@ -92,12 +96,22 @@ public PtNDArray(PtNDManager manager, 
long handle, ByteBuffer data) { */ public PtNDArray(PtNDManager manager, String[] strs, Shape shape) { super(-1L); + deallocator(new PtNDArrayDeallocator(this)); this.manager = manager; this.strs = strs; this.shape = shape; this.dataType = DataType.STRING; } + /** + * Deallocates the native memory associated with the specified {@link RCObject}. + * + * @param rco the reference count object + */ + public static void deallocate(RCObject rco) { + ((PtNDArray) rco).close(); + } + /** {@inheritDoc} */ @Override public PtNDManager getManager() { @@ -1591,11 +1605,16 @@ public int hashCode() { /** {@inheritDoc} */ @Override public void close() { + if (RCConfig.isVerboseIfResourceAlreadyClosed()) { + setClosingStackTraceAsString(stackTraceAsString()); + } Long pointer = handle.getAndSet(null); if (pointer != null && pointer != -1) { JniUtils.deleteNDArray(pointer); } - manager.detachInternal(getUid()); + if (manager != null) { + manager.detachInternal(getUid()); + } dataRef = null; } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayDeallocator.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayDeallocator.java new file mode 100644 index 00000000000..97590a7c31a --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayDeallocator.java @@ -0,0 +1,49 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.pytorch.engine; + +import ai.djl.ndarray.refcount.Deallocator; +import ai.djl.ndarray.refcount.DeallocatorReference; + +/** + * A {@link Deallocator} that calls, during garbage collection, the method {@link + * PtNDArray#deallocate()} from the referenceCountedObject of type {@link PtNDArray}. + * + *

This class has been derived from {@code org.bytedeco.javacpp.Pointer} by Samuel Audet + */ +public class PtNDArrayDeallocator extends DeallocatorReference { + PtNDArray referenceCountedObject; + + /** + * Constructs and initializes a {@code PtNDArrayDeallocator} with a {@link PtNDArray} to. + * + * @param p - the {@link PtNDArray} to be deallocated + */ + public PtNDArrayDeallocator(PtNDArray p) { + super(p, null); + this.deallocator = this; + this.referenceCountedObject = p; + } + + /** {@inheritDoc} */ + @Override + public void deallocate() { + PtNDArray.deallocate(referenceCountedObject); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return getClass().getName() + "[referenceCountedObject=" + referenceCountedObject + "]"; + } +} diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCObjectTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCObjectTest.java new file mode 100644 index 00000000000..43b5d36e7c3 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCObjectTest.java @@ -0,0 +1,55 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.pytorch.refcount; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.refcount.RCObject; + +import org.testng.Assert; +import org.testng.annotations.Test; + +public class RCObjectTest { + + @Test + public void testNDArraySimpleLifecycle() { + System.out.println("NDArray simple lifecycle"); + try (NDManager manager = NDManager.newBaseManager()) { + NDArray array1 = manager.create(new int[] {1, 2, 3}); + RCObject rco1 = (RCObject) array1; + Assert.assertEquals(rco1.referenceCount(), 0); + rco1.retainReference(); + Assert.assertEquals(rco1.referenceCount(), 1); + Assert.assertTrue(rco1.releaseReference()); + Assert.assertEquals(rco1.referenceCount(), -1); + } + } + + @Test + public void testNDArraySimpleLifecycle2() { + System.out.println("NDArray simple lifecycle 2"); + try (NDManager manager = NDManager.newBaseManager()) { + NDArray array1 = manager.create(new int[] {1, 2, 3}); + RCObject rco1 = (RCObject) array1; + Assert.assertEquals(rco1.referenceCount(), 0); + rco1.retainReference(); + Assert.assertEquals(rco1.referenceCount(), 1); + rco1.retainReference(); + Assert.assertEquals(rco1.referenceCount(), 2); + Assert.assertFalse(rco1.releaseReference()); + Assert.assertEquals(rco1.referenceCount(), 1); + Assert.assertTrue(rco1.releaseReference()); + Assert.assertEquals(rco1.referenceCount(), -1); + } + } +} diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCScopeTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCScopeTest.java new file mode 100644 index 00000000000..67261198946 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCScopeTest.java @@ -0,0 +1,105 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. 
A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.refcount; + +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.refcount.RCObject; +import ai.djl.ndarray.refcount.RCScope; +import ai.djl.pytorch.engine.PtNDArray; + +import org.testng.Assert; +import org.testng.annotations.Test; + +public class RCScopeTest { + + @Test + public void testRCScopeBaseProperties() { + System.out.println("RCScope base properties"); + + try (NDManager manager = NDManager.newBaseManager()) { + RCObject outside = (RCObject) manager.create(new int[] {1}); + RCObject attached = (RCObject) manager.create(new int[] {1}); + RCObject detached; + RCObject inside; + RCObject inside1; + RCObject inside2; + RCObject retained1; + RCObject retained2; + RCObject inside5; + + try (RCScope scope = new RCScope()) { + scope.attach(attached); + + detached = (RCObject) manager.create(new int[] {1}); + detached.retainReference(); + scope.detach(detached); + + inside = (RCObject) manager.create(new int[] {1}); + try (RCScope scope1 = new RCScope()) { + scope1.suppressNotUsedWarning(); + inside1 = (RCObject) manager.create(new int[] {1}); + inside2 = (RCObject) manager.create(new int[] {1}); + } + try (RCScope scope2 = new RCScope()) { + scope2.suppressNotUsedWarning(); + retained1 = (RCObject) manager.create(new int[] {1}); + retained2 = (RCObject) manager.create(new int[] {1}); + retained1.retainReference(); + scope.attach(retained2); + } + retained2.retainReference(); + inside5 = (RCObject) manager.create(new int[] {1}); + } + + RCObject outside2 = (RCObject) manager.create(new int[] {1}); + + Assert.assertFalse(outside.isNull()); + Assert.assertTrue(attached.isNull()); 
+ Assert.assertFalse(detached.isNull()); + Assert.assertTrue(inside.isNull()); + Assert.assertTrue(inside1.isNull()); + Assert.assertTrue(inside2.isNull()); + Assert.assertFalse(retained1.isNull()); + Assert.assertFalse(retained2.isNull()); + Assert.assertTrue(inside5.isNull()); + Assert.assertFalse(outside2.isNull()); + + outside.releaseReference(); + detached.releaseReference(); + retained1.releaseReference(); + retained2.releaseReference(); + outside2.releaseReference(); + + Assert.assertTrue(outside.isNull()); + Assert.assertTrue(detached.isNull()); + Assert.assertTrue(retained1.isNull()); + Assert.assertTrue(retained2.isNull()); + Assert.assertTrue(outside2.isNull()); + } + } + + @Test + public void testRCScopeDetachingFromManager() { + System.out.println("RCScope detaching from manager"); + PtNDArray inside; + + try (NDManager manager = NDManager.newBaseManager()) { + try (RCScope scope = new RCScope()) { + scope.suppressNotUsedWarning(); + inside = (PtNDArray) manager.create(new int[] {1}); + } + Assert.assertFalse(inside.getManager().hasResource(inside)); + Assert.assertFalse(inside.getManager().hasTempResource(inside)); + } + } +} diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/package-info.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/package-info.java new file mode 100644 index 00000000000..1581cabfd1b --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. 
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** + * Contains test for the reference counting implementation derived from JavaCPP's Pointer and PointerScope. + * The test are derived from the JavaCPP's PointerTest. + */ +package ai.djl.pytorch.refcount; diff --git a/gradle.properties b/gradle.properties index ed47e5387ec..35c443ce24d 100644 --- a/gradle.properties +++ b/gradle.properties @@ -43,6 +43,6 @@ tablesaw_version=0.43.1 spark_version=3.2.2 antlr_version=4.9.3 -testng_version=7.7.0 +testng_version=7.7.1 junit_version=4.13.2 mockito_version=4.8.0 diff --git a/tools/conf/findbugs-exclude.xml b/tools/conf/findbugs-exclude.xml index b36584a1714..c54c777e2d7 100644 --- a/tools/conf/findbugs-exclude.xml +++ b/tools/conf/findbugs-exclude.xml @@ -33,4 +33,16 @@ + + + + + + + + + + + + From 6ae6600782770b29b1552141cec14c0af09fd778 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 25 Jan 2023 20:44:51 -0800 Subject: [PATCH 03/10] Simplify RCScope implementation --- .../java/ai/djl/ndarray/BaseNDManager.java | 150 +++++------- api/src/main/java/ai/djl/ndarray/NDScope.java | 83 +++++++ .../ai/djl/ndarray/refcount/Deallocator.java | 23 -- .../refcount/DeallocatorReference.java | 131 ---------- .../ai/djl/ndarray/refcount/RCConfig.java | 39 --- .../ai/djl/ndarray/refcount/RCException.java | 47 ---- .../ai/djl/ndarray/refcount/RCObject.java | 225 ------------------ .../java/ai/djl/ndarray/refcount/RCScope.java | 208 ---------------- .../ndarray/refcount/ReferenceCounter.java | 39 --- .../ai/djl/ndarray/refcount/package-info.java | 18 -- .../main/java/ai/djl/util/NativeResource.java | 92 +------ .../test/java/ai/djl/ndarray/NDScopeTest.java | 44 ++++ .../java/ai/djl/mxnet/engine/MxNDArray.java | 3 + .../java/ai/djl/pytorch/engine/PtNDArray.java | 25 +- 
.../pytorch/engine/PtNDArrayDeallocator.java | 49 ---- .../ai/djl/pytorch/refcount/RCObjectTest.java | 55 ----- .../ai/djl/pytorch/refcount/RCScopeTest.java | 105 -------- .../ai/djl/pytorch/refcount/package-info.java | 18 -- .../ai/djl/tensorflow/engine/TfNDArray.java | 3 + gradle.properties | 2 +- tools/conf/findbugs-exclude.xml | 12 - 21 files changed, 207 insertions(+), 1164 deletions(-) create mode 100644 api/src/main/java/ai/djl/ndarray/NDScope.java delete mode 100644 api/src/main/java/ai/djl/ndarray/refcount/Deallocator.java delete mode 100644 api/src/main/java/ai/djl/ndarray/refcount/DeallocatorReference.java delete mode 100644 api/src/main/java/ai/djl/ndarray/refcount/RCConfig.java delete mode 100644 api/src/main/java/ai/djl/ndarray/refcount/RCException.java delete mode 100644 api/src/main/java/ai/djl/ndarray/refcount/RCObject.java delete mode 100644 api/src/main/java/ai/djl/ndarray/refcount/RCScope.java delete mode 100644 api/src/main/java/ai/djl/ndarray/refcount/ReferenceCounter.java delete mode 100644 api/src/main/java/ai/djl/ndarray/refcount/package-info.java create mode 100644 api/src/test/java/ai/djl/ndarray/NDScopeTest.java delete mode 100644 engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayDeallocator.java delete mode 100644 engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCObjectTest.java delete mode 100644 engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCScopeTest.java delete mode 100644 engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/package-info.java diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 8addec50d04..abbc3c31605 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -63,77 +63,6 @@ protected BaseNDManager(NDManager parent, Device device) { } } - /** - * Checks if the input buffer size is match 
expected data type. - * - * @param buffer the input buffer - * @param dataType the desired {@code DataType} - * @param expected the expected size - * @throws IllegalArgumentException if buffer size is invalid - */ - public static void validateBuffer(Buffer buffer, DataType dataType, int expected) { - boolean isByteBuffer = buffer instanceof ByteBuffer; - DataType type = DataType.fromBuffer(buffer); - if (type != dataType && !isByteBuffer) { - // It's ok if type != datatype and buffer is ByteBuffer, - // since buffer will be copied into ByteBuffer - throw new IllegalArgumentException( - "The input data type: " - + type - + " does not match target array data type: " - + dataType); - } - - int remaining = buffer.remaining(); - int expectedSize = isByteBuffer ? dataType.getNumOfBytes() * expected : expected; - if (remaining < expectedSize) { - throw new IllegalArgumentException( - "The NDArray size is: " + expected + ", but buffer size is: " + remaining); - } - if (remaining > expectedSize) { - logger.warn( - "Input buffer size is greater than the NDArray size, please set limit" - + " explicitly."); - buffer.limit(expectedSize); - } - } - - /** - * Copies data from the source {@code Buffer} to the target {@code ByteBuffer}. 
- * - * @param src the source {@code Buffer} - * @param target the target {@code ByteBuffer} - */ - public static void copyBuffer(Buffer src, ByteBuffer target) { - target.rewind(); - DataType inputType = DataType.fromBuffer(src); - switch (inputType) { - case FLOAT16: - target.asShortBuffer().put((ShortBuffer) src); - break; - case FLOAT32: - target.asFloatBuffer().put((FloatBuffer) src); - break; - case FLOAT64: - target.asDoubleBuffer().put((DoubleBuffer) src); - break; - case UINT8: - case INT8: - case BOOLEAN: - target.put((ByteBuffer) src); - break; - case INT32: - target.asIntBuffer().put((IntBuffer) src); - break; - case INT64: - target.asLongBuffer().put((LongBuffer) src); - break; - default: - throw new AssertionError("Unsupported datatype: " + inputType); - } - target.rewind(); - } - /** {@inheritDoc} */ @Override public final Device defaultDevice() { @@ -178,14 +107,14 @@ public NDList load(Path path) { /** {@inheritDoc} */ @Override - public String getName() { - return this.name == null ? uid : this.name; + public void setName(String name) { + this.name = name; } /** {@inheritDoc} */ @Override - public void setName(String name) { - this.name = name; + public String getName() { + return this.name == null ? uid : this.name; } /** {@inheritDoc} */ @@ -539,23 +468,74 @@ NDManager getAlternativeManager() { } /** - * Returns true if the resource is a resource of this manager. + * Checks if the input buffer size is match expected data type. 
* - * @param resource the resource to check - * @return true if the resource is a resource of this manager + * @param buffer the input buffer + * @param dataType the desired {@code DataType} + * @param expected the expected size + * @throws IllegalArgumentException if buffer size is invalid */ - public boolean hasResource(AutoCloseable resource) { - return this.resources.values().contains(resource); + public static void validateBuffer(Buffer buffer, DataType dataType, int expected) { + boolean isByteBuffer = buffer instanceof ByteBuffer; + DataType type = DataType.fromBuffer(buffer); + if (type != dataType && !isByteBuffer) { + // It's ok if type != datatype and buffer is ByteBuffer, + // since buffer will be copied into ByteBuffer + throw new IllegalArgumentException( + "The input data type: " + + type + + " does not match target array data type: " + + dataType); + } + + int remaining = buffer.remaining(); + int expectedSize = isByteBuffer ? dataType.getNumOfBytes() * expected : expected; + if (remaining < expectedSize) { + throw new IllegalArgumentException( + "The NDArray size is: " + expected + ", but buffer size is: " + remaining); + } + if (remaining > expectedSize) { + logger.warn( + "Input buffer size is greater than the NDArray size, please set limit" + + " explicitly."); + buffer.limit(expectedSize); + } } /** - * Returns true if the resource is a temporary resource. + * Copies data from the source {@code Buffer} to the target {@code ByteBuffer}. 
* - * @param resource the resource to check - * @return true if the resource is a temporary resource + * @param src the source {@code Buffer} + * @param target the target {@code ByteBuffer} */ - public boolean hasTempResource(AutoCloseable resource) { - return this.resources.values().contains(resource); + public static void copyBuffer(Buffer src, ByteBuffer target) { + target.rewind(); + DataType inputType = DataType.fromBuffer(src); + switch (inputType) { + case FLOAT16: + target.asShortBuffer().put((ShortBuffer) src); + break; + case FLOAT32: + target.asFloatBuffer().put((FloatBuffer) src); + break; + case FLOAT64: + target.asDoubleBuffer().put((DoubleBuffer) src); + break; + case UINT8: + case INT8: + case BOOLEAN: + target.put((ByteBuffer) src); + break; + case INT32: + target.asIntBuffer().put((IntBuffer) src); + break; + case INT64: + target.asLongBuffer().put((LongBuffer) src); + break; + default: + throw new AssertionError("Unsupported datatype: " + inputType); + } + target.rewind(); } protected static final class TempResource { diff --git a/api/src/main/java/ai/djl/ndarray/NDScope.java b/api/src/main/java/ai/djl/ndarray/NDScope.java new file mode 100644 index 00000000000..7adda1ff436 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/NDScope.java @@ -0,0 +1,83 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.ndarray; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; + +/** + * A class that tracks {@link NDResource} objects created in the try-with-resource block and close + * them automatically when out of the block scope. + * + *

This class has been derived from {@code org.bytedeco.javacpp.PointerScope} by Samuel Audet + */ +public class NDScope implements AutoCloseable { + + private static final ThreadLocal> SCOPE_STACK = + ThreadLocal.withInitial(ArrayDeque::new); + + private List resources; + + /** Constructs a new {@code NDScope} instance. */ + public NDScope() { + resources = new ArrayList<>(); + SCOPE_STACK.get().addLast(this); + } + + /** + * Registers {@link NDArray} object to this scope. + * + * @param array the {@link NDArray} object + */ + public static void register(NDArray array) { + Deque queue = SCOPE_STACK.get(); + if (queue.isEmpty()) { + return; + } + queue.getLast().resources.add(array); + } + + /** + * Unregisters {@link NDArray} object from this scope. + * + * @param array the {@link NDArray} object + */ + public static void unregister(NDArray array) { + Deque queue = SCOPE_STACK.get(); + if (queue.isEmpty()) { + return; + } + queue.getLast().resources.remove(array); + } + + /** {@inheritDoc} */ + @Override + public void close() { + for (NDArray array : resources) { + array.close(); + } + SCOPE_STACK.get().remove(this); + } + + /** + * A method that does nothing. + * + *

You may use it if you do not have a better way to suppress the warning of a created but + * not explicitly used scope. + */ + public void suppressNotUsedWarning() { + // do nothing + } +} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/Deallocator.java b/api/src/main/java/ai/djl/ndarray/refcount/Deallocator.java deleted file mode 100644 index 1f180a687b2..00000000000 --- a/api/src/main/java/ai/djl/ndarray/refcount/Deallocator.java +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.ndarray.refcount; - -/** - * The interface to implement to produce a Deallocator usable by referenceCountedObject. - * - *

This interface has been derived from {@code org.bytedeco.javacpp.Pointer} by Samuel Audet - */ -public interface Deallocator { - /** The method to implement to produce a Deallocator usable by referenceCountedObject. */ - void deallocate(); -} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/DeallocatorReference.java b/api/src/main/java/ai/djl/ndarray/refcount/DeallocatorReference.java deleted file mode 100644 index 1b056ededb6..00000000000 --- a/api/src/main/java/ai/djl/ndarray/refcount/DeallocatorReference.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.ndarray.refcount; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.lang.ref.PhantomReference; -import java.util.concurrent.atomic.AtomicInteger; - -/** - * A subclass of {@link PhantomReference} that also acts as a linked list to keep their references - * alive until they get garbage collected. Implements reference counting with an {@link - * AtomicInteger} count. - * - *

This class has been derived from {@code org.bytedeco.javacpp.Pointer} by Samuel Audet - */ -@SuppressWarnings({"PMD.MutableStaticState", "PMD.AvoidUsingVolatile"}) -public class DeallocatorReference extends PhantomReference - implements Deallocator, ReferenceCounter { - - private static final Logger logger = LoggerFactory.getLogger(DeallocatorReference.class); - static volatile DeallocatorReference head; - static volatile long totalCount; - volatile DeallocatorReference prev = this; - volatile DeallocatorReference next = this; - protected Deallocator deallocator; - AtomicInteger count; - - protected DeallocatorReference(RCObject p, Deallocator deallocator) { - super(p, null); - this.deallocator = deallocator; - this.count = new AtomicInteger(0); - } - - final void add() { - synchronized (DeallocatorReference.class) { - if (head == null) { - head = this; - prev = next = null; - } else { - prev = null; - next = head; - next.prev = head = this; - } - totalCount++; - } - } - - final void remove() { - if (prev == this && next == this) { - return; - } - synchronized (DeallocatorReference.class) { - if (prev == null) { - head = next; - } else { - prev.next = next; - } - if (next != null) { - next.prev = prev; - } - prev = next = this; - totalCount--; - } - } - - /** {@inheritDoc} */ - @Override - public void clear() { - super.clear(); - if (deallocator != null) { - if (logger.isDebugEnabled()) { - logger.trace("Collecting " + this); - } - deallocate(); - } - } - - /** {@inheritDoc} */ - @Override - public void deallocate() { - if (deallocator != null) { - deallocator.deallocate(); - deallocator = null; - } - } - - /** {@inheritDoc} */ - @Override - public void retain() { - if (deallocator != null) { - count.incrementAndGet(); - } - } - - /** {@inheritDoc} */ - @Override - public boolean release() { - if (deallocator != null && count.decrementAndGet() <= 0) { - if (logger.isDebugEnabled()) { - logger.trace("Releasing " + this); - } - deallocate(); - return true; - } - 
return false; - } - - /** {@inheritDoc} */ - @Override - public int count() { - return deallocator != null ? count.get() : -1; - } - - /** {@inheritDoc} */ - @Override - public String toString() { - return getClass().getName() + "[deallocator=" + deallocator + ",count=" + count + "]"; - } -} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/RCConfig.java b/api/src/main/java/ai/djl/ndarray/refcount/RCConfig.java deleted file mode 100644 index b10349a19c1..00000000000 --- a/api/src/main/java/ai/djl/ndarray/refcount/RCConfig.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.ndarray.refcount; - -/** RCConfig holds configurable parameters. */ -public final class RCConfig { - - private static boolean verboseIfResourceAlreadyClosed; - - private RCConfig() {} - - /** - * If true, a verbose message is printed if a resource is already closed. - * - * @return true if the verboseIfResourceAlreadyClosed is set - */ - public static boolean isVerboseIfResourceAlreadyClosed() { - return verboseIfResourceAlreadyClosed; - } - - /** - * If true, a verbose message is printed if a resource is already closed. 
- * - * @param verboseIfResourceAlreadyClosed parameter to set - */ - public static void setVerboseIfResourceAlreadyClosed(boolean verboseIfResourceAlreadyClosed) { - RCConfig.verboseIfResourceAlreadyClosed = verboseIfResourceAlreadyClosed; - } -} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/RCException.java b/api/src/main/java/ai/djl/ndarray/refcount/RCException.java deleted file mode 100644 index bbdc494eddb..00000000000 --- a/api/src/main/java/ai/djl/ndarray/refcount/RCException.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.ndarray.refcount; - -/** A runtime exception thrown within the reference counting package. */ -public class RCException extends RuntimeException { - - private static final long serialVersionUID = 1L; - - /** - * Constructs a new runtime exception with the specified detail message. - * - * @param message - the message to be displayed - */ - public RCException(String message) { - super(message); - } - - /** - * Constructs a new runtime exception with the specified detail message and cause. - * - * @param message - the message to be displayed - * @param cause - the cause of the exception - */ - public RCException(String message, Throwable cause) { - super(message, cause); - } - - /** - * Constructs a new runtime exception with the specified cause. 
- * - * @param cause - the cause of the exception - */ - public RCException(Throwable cause) { - super(cause); - } -} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/RCObject.java b/api/src/main/java/ai/djl/ndarray/refcount/RCObject.java deleted file mode 100644 index 248ba6e04b0..00000000000 --- a/api/src/main/java/ai/djl/ndarray/refcount/RCObject.java +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.ndarray.refcount; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Iterator; - -/** - * All peer classes to native types must be descended from {@link RCObject} (reference counted - * object), the topmost class. - * - *

It is also possible to use a {@link RCScope} to keep track of a group of {@link RCObject} - * objects, and have them deallocated in a transparent but deterministic manner. - * - *

This class has been derived from {@code org.bytedeco.javacpp.Pointer} by Samuel Audet - */ -@SuppressWarnings("PMD.AvoidBranchingStatementAsLastInLoop") -public class RCObject { - - private static final Logger logger = LoggerFactory.getLogger(RCObject.class); - - private Deallocator deallocator; - - /** - * Returns {@link DeallocatorReference#totalCount}, current number of ReferenceCountedObjects - * tracked by deallocators. - * - * @return the total count - */ - public static long totalCount() { - return DeallocatorReference.totalCount; - } - - /** - * Returns {@link ReferenceCounter#count()} or -1 if no deallocator has been set. - * - * @return the count - */ - public int referenceCount() { - ReferenceCounter r = (ReferenceCounter) deallocator; - return r != null ? r.count() : -1; - } - - /** - * Returns {@link #deallocator}. - * - * @return the deallocator - */ - protected Deallocator deallocator() { - return deallocator; - } - - /** - * Sets the deallocator and returns this. Also clears current deallocator if not {@code null}. - * That is, it deallocates previously allocated memory. Should not be called more than once - * after allocation. - * - * @param deallocator the new deallocator - * @param

the type of the referenceCountedObject - * @return this referenceCountedObject - */ - protected

P deallocator(Deallocator deallocator) { - if (this.deallocator != null) { - if (logger.isDebugEnabled()) { - logger.debug("Predeallocating " + this); - } - this.deallocator.deallocate(); - this.deallocator = null; - } - if (deallocator != null) { - DeallocatorReference r = - deallocator instanceof DeallocatorReference - ? (DeallocatorReference) deallocator - : new DeallocatorReference(this, deallocator); - this.deallocator = r; - Iterator it = RCScope.getScopeIterator(); - if (it != null) { - while (it.hasNext()) { - try { - it.next().attach(this); - } catch (IllegalArgumentException e) { - // try the next scope down the stack - continue; - } - break; - } - } - } - - @SuppressWarnings("unchecked") - P p = (P) this; - return p; - } - - /** Calls {@code deallocate(true)}. */ - public void deallocate() { - deallocate(true); - } - - /** - * Explicitly manages native memory without waiting after the garbage collector. Has no effect - * if no deallocator was previously set with {@link #deallocator(Deallocator)}. - * - * @param deallocate if true, deallocates, else does not, but disables garbage collection - */ - public void deallocate(boolean deallocate) { - DeallocatorReference r = (DeallocatorReference) deallocator; - if (deallocate && deallocator != null) { - if (logger.isDebugEnabled()) { - logger.debug("Deallocating " + this); - } - deallocator.deallocate(); - deallocator = null; - // address = 0; - } - if (r != null) { - // remove from queue without calling the deallocator - r.deallocator = null; - r.clear(); - r.remove(); - r.deallocator = deallocator; - } - } - - /** - * Calls {@link ReferenceCounter#retain()}, incrementing the reference count by 1. Has no effect - * if no deallocator was previously set with {@link #deallocator(Deallocator)}. - * - * @param

the type of the referenceCountedObject - * @return this - */ - public

P retainReference() { - ReferenceCounter r = (ReferenceCounter) deallocator; - if (r != null) { - r.retain(); - } - @SuppressWarnings("unchecked") - P p = (P) this; - return p; - } - - /** - * Calls {@link ReferenceCounter#release()}, decrementing the reference count by 1, in turn - * deallocating this referenceCountedObject when the count drops to 0. Has no effect if no - * deallocator was previously set with {@link #deallocator(Deallocator)}. - * - * @return true when the count drops to 0 and deallocation has occurred - */ - public boolean releaseReference() { - DeallocatorReference r = (DeallocatorReference) deallocator; - if (r != null && r.release()) { - deallocator = null; - // address = 0; - r.clear(); - r.remove(); - return true; - } - return false; - } - - /** - * Calls in effect {@code memcpy(this.address + this.position, p.address + p.position, length)}, - * where {@code length = sizeof(p) * (p.limit - p.position)}. If limit == 0, it uses position + - * 1 instead. The way the methods were designed allows constructs such as {@code - * this.position(0).put(p.position(13).limit(42))}. - * - * @param p the referenceCountedObject from which to copy memory - * @param

the type of the referenceCountedObject - * @return this - */ - public

P put(RCObject p) { - @SuppressWarnings("unchecked") - P p2 = (P) this; - return p2; - } - - /** - * Calls in effect {@code memset(address + position, b, length)}, where {@code length = sizeof() - * * (limit - position)}. If limit == 0, it uses position + 1 instead. The way the methods were - * designed allows constructs such as {@code this.position(0).limit(13).fill(42)}; - * - * @param b the byte value to fill the memory with - * @param

the type of the referenceCountedObject - * @return this - */ - public

P fill(int b) { - @SuppressWarnings("unchecked") - P p = (P) this; - return p; - } - - /** - * Returns {@code fill(0)}. - * - * @param

the type of the referenceCountedObject - * @return this - */ - public

P zero() { - // repair warning: [unchecked] unchecked cast - @SuppressWarnings("unchecked") - P p = (P) this.fill(0); - return p; - } - - /** - * Returns whether the resource is null. - * - * @return whether the resource is null - */ - public boolean isNull() { - throw new UnsupportedOperationException("Not implemented."); - } -} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/RCScope.java b/api/src/main/java/ai/djl/ndarray/refcount/RCScope.java deleted file mode 100644 index 11389f5bbec..00000000000 --- a/api/src/main/java/ai/djl/ndarray/refcount/RCScope.java +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.ndarray.refcount; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayDeque; -import java.util.Arrays; -import java.util.Deque; -import java.util.Iterator; - -/** - * {@link RCObject} objects attach themselves automatically on creation to the first {@link RCScope} - * found in {@link #SCOPE_STACK} that they can to based on the classes found in {@link #forClasses}. - * The user can then call {@link #deallocate()}, or rely on {@link #close()} to release in a timely - * fashion all attached referenceCountedObject objects, instead of relying on the garbage collector. - * - *

This class has been derived from {@code org.bytedeco.javacpp.PointerScope} by Samuel Audet - */ -public class RCScope implements AutoCloseable { - /** - * A thread-local stack of {@link RCScope} objects. referenceCountedObject objects attach - * themselves automatically to the first one they can to on the stack. - */ - static final ThreadLocal> SCOPE_STACK = - new ThreadLocal>() { - @Override - protected Deque initialValue() { - return new ArrayDeque(); - } - }; - - private static final Logger logger = LoggerFactory.getLogger(RCScope.class); - /** The stack keeping references to attached {@link RCObject} objects. */ - Deque referenceCountedObjectStack = new ArrayDeque<>(); - /** When not empty, indicates the classes of objects that are allowed to be attached. */ - Class[] forClasses; - /** - * When set to true, the next call to {@link #close()} does not release but resets this - * variable. - */ - boolean extend; - - /** - * Creates a new scope accepting all referenceCountedObject types and pushes itself on the - * {@link #SCOPE_STACK}. - */ - public RCScope() { - this((Class[]) null); - } - - /** - * Initializes {@link #forClasses}, and adds itself as first (push) on the {@link #SCOPE_STACK}. - * - * @param forClasses the classes of objects that are allowed to be attached - */ - @SafeVarargs - @SuppressWarnings("varargs") - public RCScope(Class... forClasses) { - if (logger.isDebugEnabled()) { - logger.debug("Opening " + this); - } - this.forClasses = forClasses; - SCOPE_STACK.get().addFirst(this); - } - - /** - * Returns {@code SCOPE_STACK.get().peekFirst()} (peek), the last opened scope not yet closed. - * - * @return the last opened scope not yet closed - */ - public static RCScope getInnerScope() { - return SCOPE_STACK.get().peekFirst(); - } - - /** - * Returns {@code SCOPE_STACK.get().iterator()}, all scopes not yet closed. 
- * - * @return all scopes not yet closed - */ - public static Iterator getScopeIterator() { - return SCOPE_STACK.get().iterator(); - } - - /** - * When not empty, returns the classes of objects that are allowed to be attached. - * - * @return the classes of objects that are allowed to be attached - */ - public Class[] forClasses() { - return forClasses; - } - - /** - * Pushes the referenceCountedObject onto the {@link #referenceCountedObjectStack} of this Scope - * and calls {@link RCObject#retainReference()}. - * - * @param p the referenceCountedObject to attach - * @return the referenceCountedObject - * @throws IllegalArgumentException when it is not an instance of a class in {@link - * #forClasses}. - */ - public RCScope attach(RCObject p) { - if (logger.isDebugEnabled()) { - logger.debug("Attaching " + p + " to " + this); - } - if (forClasses != null && forClasses.length > 0) { - boolean found = false; - for (Class c : forClasses) { - if (c != null && c.isInstance(p)) { - found = true; - break; - } - } - if (!found) { - throw new IllegalArgumentException( - p - + " is not an instance of a class in forClasses: " - + Arrays.toString(forClasses)); - } - } - referenceCountedObjectStack.push(p); - p.retainReference(); - return this; - } - - /** - * Removes the referenceCountedObject from the {@link #referenceCountedObjectStack} of this - * Scope and calls {@link RCObject#releaseReference()}. - * - * @param p the referenceCountedObject to detach - * @return the referenceCountedObject - */ - public RCScope detach(RCObject p) { - if (logger.isDebugEnabled()) { - logger.debug("Detaching " + p + " from " + this); - } - referenceCountedObjectStack.remove(p); - p.releaseReference(); - return this; - } - - /** - * Extends the life of this scope past the next call to {@link #close()} by setting the {@link - * #extend} flag. 
- * - * @return this scope - */ - public RCScope extend() { - if (logger.isDebugEnabled()) { - logger.debug("Extending " + this); - } - extend = true; - return this; - } - - /** - * Pops from {@link #referenceCountedObjectStack} all attached ReferenceCountedObjects, calls - * {@link RCObject#releaseReference()} on them, unless extended, in which case it only resets - * the {@link #extend} flag instead, and finally removes itself from {@link #SCOPE_STACK}. - */ - @Override - public void close() { - if (logger.isDebugEnabled()) { - logger.debug("Closing " + this); - } - if (extend) { - extend = false; - } else { - while (referenceCountedObjectStack.size() > 0) { - referenceCountedObjectStack.pop().releaseReference(); - } - } - SCOPE_STACK.get().remove(this); - } - - /** - * Pops from {@link #referenceCountedObjectStack} all attached ReferenceCountedObjects, and - * calls {@link RCObject#deallocate()} on them. - */ - public void deallocate() { - if (logger.isDebugEnabled()) { - logger.debug("Deallocating " + this); - } - while (referenceCountedObjectStack.size() > 0) { - referenceCountedObjectStack.pop().deallocate(); - } - } - - /** - * A method that does nothing. You may use it if you do not have a better way to suppress the - * warning of a created but not explicitly used scope. - */ - public void suppressNotUsedWarning() { - // do nothing - } -} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/ReferenceCounter.java b/api/src/main/java/ai/djl/ndarray/refcount/ReferenceCounter.java deleted file mode 100644 index 25ef134fba1..00000000000 --- a/api/src/main/java/ai/djl/ndarray/refcount/ReferenceCounter.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. 
A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.ndarray.refcount; - -/** - * The ReferenceCounter interface. - * - *

This interface has been derived from {@code org.bytedeco.javacpp.Pointer} by Samuel Audet - */ -public interface ReferenceCounter { - - /** Increments the reference count by 1 starting from initially 0. */ - void retain(); - - /** - * Decrements the reference count by 1, in turn deallocating this Pointer when the count drops - * to 0. - * - * @return true when the count drops to 0 and deallocation has occurred - */ - boolean release(); - - /** - * Returns the count value. - * - * @return the count value - */ - int count(); -} diff --git a/api/src/main/java/ai/djl/ndarray/refcount/package-info.java b/api/src/main/java/ai/djl/ndarray/refcount/package-info.java deleted file mode 100644 index cbee2abc6d1..00000000000 --- a/api/src/main/java/ai/djl/ndarray/refcount/package-info.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ - -/** - * Contains a reference counting implementation derived from JavaCPP's Pointer and PointerScope - * helping to avoid memory leaks in {@link ai.djl.ndarray.NDArray}. 
- */ -package ai.djl.ndarray.refcount; diff --git a/api/src/main/java/ai/djl/util/NativeResource.java b/api/src/main/java/ai/djl/util/NativeResource.java index c0a12a0dd46..65aa8f0085a 100644 --- a/api/src/main/java/ai/djl/util/NativeResource.java +++ b/api/src/main/java/ai/djl/util/NativeResource.java @@ -12,15 +12,8 @@ */ package ai.djl.util; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.refcount.RCConfig; -import ai.djl.ndarray.refcount.RCObject; - import com.sun.jna.Pointer; -import java.text.MessageFormat; -import java.time.Instant; -import java.util.Arrays; import java.util.concurrent.atomic.AtomicReference; /** @@ -29,74 +22,14 @@ * * @param the resource that could map to a native pointer or java object */ -@SuppressWarnings("PMD.ConstructorCallsOverridableMethod") -public abstract class NativeResource extends RCObject implements AutoCloseable { +public abstract class NativeResource implements AutoCloseable { protected final AtomicReference handle; - protected Instant creationTime; - - protected String uid; - private String creationStackTraceAsString; - private String closingStackTraceAsString; - - /** Constructs a new {@code NativeResource}. 
*/ - public NativeResource() { - // super(); - handle = new AtomicReference<>(); - this.creationTime = Instant.now(); - if (RCConfig.isVerboseIfResourceAlreadyClosed()) { - creationStackTraceAsString = stackTraceAsString(); - } - } + private String uid; protected NativeResource(T handle) { this.handle = new AtomicReference<>(handle); uid = handle.toString(); - this.creationTime = Instant.now(); - if (RCConfig.isVerboseIfResourceAlreadyClosed()) { - creationStackTraceAsString = stackTraceAsString(); - } - } - - private String fingerPrintOfNativeResourceWithStackTraceFromCreation() { - String name = "NO_NAME"; - if (this instanceof NDArray) { - name = ((NDArray) this).getName(); - } - return MessageFormat.format( - "NDArray named \"{0}\" identified by (uid:{1};createdAt:{2}) \n" - + "call stack at creation...{3}\n" - + "######### \n" - + "call stack at closing...{4}\n" - + "#########", - name, - getUid(), - creationTime, - creationStackTraceAsString, - closingStackTraceAsString); - } - - /** - * Returns the current stack trace as a string. - * - * @return the current stack trace as a string - */ - public static String stackTraceAsString() { - StringBuilder buf = new StringBuilder(); - Arrays.stream(Thread.currentThread().getStackTrace()) - .forEach( - s -> - buf.append( - "\nat " - + s.getClassName() - + "." - + s.getMethodName() - + "(" - + s.getFileName() - + ":" - + s.getLineNumber() - + ")")); - return buf.toString(); } /** @@ -116,11 +49,7 @@ public boolean isReleased() { public T getHandle() { T reference = handle.get(); if (reference == null) { - String message = "Native resource has been released already. 
"; - if (RCConfig.isVerboseIfResourceAlreadyClosed()) { - message += fingerPrintOfNativeResourceWithStackTraceFromCreation(); - } - throw new IllegalStateException(message); + throw new IllegalStateException("Native resource has been release already."); } return reference; } @@ -134,21 +63,6 @@ public final String getUid() { return uid; } - /** - * Sets the closingStackTraceAsString. - * - * @param closingStackTraceAsString the closingStackTraceAsString to set - */ - protected void setClosingStackTraceAsString(String closingStackTraceAsString) { - this.closingStackTraceAsString = closingStackTraceAsString; - } - - /** {@inheritDoc} */ - @Override - public boolean isNull() { - return handle.get() == null; - } - /** {@inheritDoc} */ @Override public void close() { diff --git a/api/src/test/java/ai/djl/ndarray/NDScopeTest.java b/api/src/test/java/ai/djl/ndarray/NDScopeTest.java new file mode 100644 index 00000000000..b3f4944387c --- /dev/null +++ b/api/src/test/java/ai/djl/ndarray/NDScopeTest.java @@ -0,0 +1,44 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.ndarray; + +import org.testng.Assert; +import org.testng.annotations.Test; + +public class NDScopeTest { + + @Test + @SuppressWarnings("try") + public void testNDScope() { + NDArray detached; + NDArray inside; + try (NDManager manager = NDManager.newBaseManager()) { + try (NDScope scope = new NDScope()) { + scope.suppressNotUsedWarning(); + try (NDScope ignore = new NDScope()) { + inside = manager.create(new int[] {1}); + // not tracked by any NDScope, but still managed by NDManager + NDScope.unregister(inside); + } + + detached = manager.create(new int[] {1}); + detached.detach(); // detached from NDManager and NDScope + } + + Assert.assertFalse(inside.isReleased()); + } + Assert.assertTrue(inside.isReleased()); + Assert.assertFalse(detached.isReleased()); + detached.close(); + } +} diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index 10dbf7d14db..e71537cd87a 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.NDScope; import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; @@ -89,6 +90,7 @@ public class MxNDArray extends NativeResource implements LazyNDArray { this.manager = manager; mxNDArrayEx = new MxNDArrayEx(this); manager.attachInternal(getUid(), this); + NDScope.register(this); } /** @@ -187,6 +189,7 @@ public void tempAttach(NDManager manager) { public void detach() { manager.detachInternal(getUid()); manager = MxNDManager.getSystemManager(); + NDScope.unregister(this); } private NDArray duplicate( diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java 
b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index d06e183a48f..3c452454186 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -17,8 +17,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.refcount.RCConfig; -import ai.djl.ndarray.refcount.RCObject; +import ai.djl.ndarray.NDScope; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; @@ -66,7 +65,7 @@ public PtNDArray(PtNDManager manager, long handle) { this.manager = manager; this.ptNDArrayEx = new PtNDArrayEx(this); manager.attachInternal(getUid(), this); - deallocator(new PtNDArrayDeallocator(this)); + NDScope.register(this); } /** @@ -79,11 +78,11 @@ public PtNDArray(PtNDManager manager, long handle) { */ public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) { super(handle); - deallocator(new PtNDArrayDeallocator(this)); this.manager = manager; this.ptNDArrayEx = new PtNDArrayEx(this); manager.attachInternal(getUid(), this); dataRef = data; + NDScope.register(this); } /** @@ -96,22 +95,12 @@ public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) { */ public PtNDArray(PtNDManager manager, String[] strs, Shape shape) { super(-1L); - deallocator(new PtNDArrayDeallocator(this)); this.manager = manager; this.strs = strs; this.shape = shape; this.dataType = DataType.STRING; } - /** - * Deallocates the native memory associated with the specified {@link RCObject}. 
- * - * @param rco the reference count object - */ - public static void deallocate(RCObject rco) { - ((PtNDArray) rco).close(); - } - /** {@inheritDoc} */ @Override public PtNDManager getManager() { @@ -377,6 +366,7 @@ public void tempAttach(NDManager manager) { public void detach() { manager.detachInternal(getUid()); manager = PtNDManager.getSystemManager(); + NDScope.unregister(this); } /** {@inheritDoc} */ @@ -1605,16 +1595,11 @@ public int hashCode() { /** {@inheritDoc} */ @Override public void close() { - if (RCConfig.isVerboseIfResourceAlreadyClosed()) { - setClosingStackTraceAsString(stackTraceAsString()); - } Long pointer = handle.getAndSet(null); if (pointer != null && pointer != -1) { JniUtils.deleteNDArray(pointer); } - if (manager != null) { - manager.detachInternal(getUid()); - } + manager.detachInternal(getUid()); dataRef = null; } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayDeallocator.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayDeallocator.java deleted file mode 100644 index 97590a7c31a..00000000000 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayDeallocator.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. 
- */ -package ai.djl.pytorch.engine; - -import ai.djl.ndarray.refcount.Deallocator; -import ai.djl.ndarray.refcount.DeallocatorReference; - -/** - * A {@link Deallocator} that calls, during garbage collection, the method {@link - * PtNDArray#deallocate()} from the referenceCountedObject of type {@link PtNDArray}. - * - *

This class has been derived from {@code org.bytedeco.javacpp.Pointer} by Samuel Audet - */ -public class PtNDArrayDeallocator extends DeallocatorReference { - PtNDArray referenceCountedObject; - - /** - * Constructs and initializes a {@code PtNDArrayDeallocator} with a {@link PtNDArray} to. - * - * @param p - the {@link PtNDArray} to be deallocated - */ - public PtNDArrayDeallocator(PtNDArray p) { - super(p, null); - this.deallocator = this; - this.referenceCountedObject = p; - } - - /** {@inheritDoc} */ - @Override - public void deallocate() { - PtNDArray.deallocate(referenceCountedObject); - } - - /** {@inheritDoc} */ - @Override - public String toString() { - return getClass().getName() + "[referenceCountedObject=" + referenceCountedObject + "]"; - } -} diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCObjectTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCObjectTest.java deleted file mode 100644 index 43b5d36e7c3..00000000000 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCObjectTest.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. 
- */ -package ai.djl.pytorch.refcount; - -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.refcount.RCObject; - -import org.testng.Assert; -import org.testng.annotations.Test; - -public class RCObjectTest { - - @Test - public void testNDArraySimpleLifecycle() { - System.out.println("NDArray simple lifecycle"); - try (NDManager manager = NDManager.newBaseManager()) { - NDArray array1 = manager.create(new int[] {1, 2, 3}); - RCObject rco1 = (RCObject) array1; - Assert.assertEquals(rco1.referenceCount(), 0); - rco1.retainReference(); - Assert.assertEquals(rco1.referenceCount(), 1); - Assert.assertTrue(rco1.releaseReference()); - Assert.assertEquals(rco1.referenceCount(), -1); - } - } - - @Test - public void testNDArraySimpleLifecycle2() { - System.out.println("NDArray simple lifecycle 2"); - try (NDManager manager = NDManager.newBaseManager()) { - NDArray array1 = manager.create(new int[] {1, 2, 3}); - RCObject rco1 = (RCObject) array1; - Assert.assertEquals(rco1.referenceCount(), 0); - rco1.retainReference(); - Assert.assertEquals(rco1.referenceCount(), 1); - rco1.retainReference(); - Assert.assertEquals(rco1.referenceCount(), 2); - Assert.assertFalse(rco1.releaseReference()); - Assert.assertEquals(rco1.referenceCount(), 1); - Assert.assertTrue(rco1.releaseReference()); - Assert.assertEquals(rco1.referenceCount(), -1); - } - } -} diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCScopeTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCScopeTest.java deleted file mode 100644 index 67261198946..00000000000 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/RCScopeTest.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. 
A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.pytorch.refcount; - -import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.refcount.RCObject; -import ai.djl.ndarray.refcount.RCScope; -import ai.djl.pytorch.engine.PtNDArray; - -import org.testng.Assert; -import org.testng.annotations.Test; - -public class RCScopeTest { - - @Test - public void testRCScopeBaseProperties() { - System.out.println("RCScope base properties"); - - try (NDManager manager = NDManager.newBaseManager()) { - RCObject outside = (RCObject) manager.create(new int[] {1}); - RCObject attached = (RCObject) manager.create(new int[] {1}); - RCObject detached; - RCObject inside; - RCObject inside1; - RCObject inside2; - RCObject retained1; - RCObject retained2; - RCObject inside5; - - try (RCScope scope = new RCScope()) { - scope.attach(attached); - - detached = (RCObject) manager.create(new int[] {1}); - detached.retainReference(); - scope.detach(detached); - - inside = (RCObject) manager.create(new int[] {1}); - try (RCScope scope1 = new RCScope()) { - scope1.suppressNotUsedWarning(); - inside1 = (RCObject) manager.create(new int[] {1}); - inside2 = (RCObject) manager.create(new int[] {1}); - } - try (RCScope scope2 = new RCScope()) { - scope2.suppressNotUsedWarning(); - retained1 = (RCObject) manager.create(new int[] {1}); - retained2 = (RCObject) manager.create(new int[] {1}); - retained1.retainReference(); - scope.attach(retained2); - } - retained2.retainReference(); - inside5 = (RCObject) manager.create(new int[] {1}); - } - - RCObject outside2 = (RCObject) manager.create(new int[] {1}); - - Assert.assertFalse(outside.isNull()); - Assert.assertTrue(attached.isNull()); 
- Assert.assertFalse(detached.isNull()); - Assert.assertTrue(inside.isNull()); - Assert.assertTrue(inside1.isNull()); - Assert.assertTrue(inside2.isNull()); - Assert.assertFalse(retained1.isNull()); - Assert.assertFalse(retained2.isNull()); - Assert.assertTrue(inside5.isNull()); - Assert.assertFalse(outside2.isNull()); - - outside.releaseReference(); - detached.releaseReference(); - retained1.releaseReference(); - retained2.releaseReference(); - outside2.releaseReference(); - - Assert.assertTrue(outside.isNull()); - Assert.assertTrue(detached.isNull()); - Assert.assertTrue(retained1.isNull()); - Assert.assertTrue(retained2.isNull()); - Assert.assertTrue(outside2.isNull()); - } - } - - @Test - public void testRCScopeDetachingFromManager() { - System.out.println("RCScope detaching from manager"); - PtNDArray inside; - - try (NDManager manager = NDManager.newBaseManager()) { - try (RCScope scope = new RCScope()) { - scope.suppressNotUsedWarning(); - inside = (PtNDArray) manager.create(new int[] {1}); - } - Assert.assertFalse(inside.getManager().hasResource(inside)); - Assert.assertFalse(inside.getManager().hasTempResource(inside)); - } - } -} diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/package-info.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/package-info.java deleted file mode 100644 index 1581cabfd1b..00000000000 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/refcount/package-info.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. 
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ - -/** - * Contains test for the reference counting implementation derived from JavaCPP's Pointer and PointerScope. - * The test are derived from the JavaCPP's PointerTest. - */ -package ai.djl.pytorch.refcount; diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 920630b5236..efd0dde322b 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDArrays; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.NDScope; import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; @@ -55,6 +56,7 @@ public class TfNDArray extends NativeResource implements NDArr this.manager = manager; manager.attachInternal(getUid(), this); tfNDArrayEx = new TfNDArrayEx(this); + NDScope.register(this); } TfNDArray(TfNDManager manager, TFE_TensorHandle handle, TF_Tensor tensor) { @@ -263,6 +265,7 @@ public void tempAttach(NDManager manager) { public void detach() { manager.detachInternal(getUid()); manager = TfNDManager.getSystemManager(); + NDScope.unregister(this); } /** {@inheritDoc} */ diff --git a/gradle.properties b/gradle.properties index 35c443ce24d..ed47e5387ec 100644 --- a/gradle.properties +++ b/gradle.properties @@ -43,6 +43,6 @@ tablesaw_version=0.43.1 spark_version=3.2.2 antlr_version=4.9.3 -testng_version=7.7.1 +testng_version=7.7.0 junit_version=4.13.2 mockito_version=4.8.0 diff --git a/tools/conf/findbugs-exclude.xml 
b/tools/conf/findbugs-exclude.xml index c54c777e2d7..b36584a1714 100644 --- a/tools/conf/findbugs-exclude.xml +++ b/tools/conf/findbugs-exclude.xml @@ -33,16 +33,4 @@ - - - - - - - - - - - - From 8c5c50292a243ea36cec1938a4b9625bf438ebda Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 26 Jan 2023 12:40:30 +0100 Subject: [PATCH 04/10] missing NDScope.register --- .../src/main/java/ai/djl/pytorch/engine/PtNDArray.java | 1 + 1 file changed, 1 insertion(+) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 3c452454186..75dcffe7bbb 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -99,6 +99,7 @@ public PtNDArray(PtNDManager manager, String[] strs, Shape shape) { this.strs = strs; this.shape = shape; this.dataType = DataType.STRING; + NDScope.register(this); } /** {@inheritDoc} */ From 85e3f3d7d72c6834cda12e46e806b7ccb33e060e Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 26 Jan 2023 12:51:49 +0100 Subject: [PATCH 05/10] enhanced test to spot subtle bug on equals method + fix --- api/src/test/java/ai/djl/ndarray/NDScopeTest.java | 3 +++ .../src/main/java/ai/djl/pytorch/engine/PtNDArray.java | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/api/src/test/java/ai/djl/ndarray/NDScopeTest.java b/api/src/test/java/ai/djl/ndarray/NDScopeTest.java index b3f4944387c..ab8268ca6f0 100644 --- a/api/src/test/java/ai/djl/ndarray/NDScopeTest.java +++ b/api/src/test/java/ai/djl/ndarray/NDScopeTest.java @@ -22,10 +22,13 @@ public class NDScopeTest { public void testNDScope() { NDArray detached; NDArray inside; + NDArray uninvolved; try (NDManager manager = NDManager.newBaseManager()) { try (NDScope scope = new NDScope()) { scope.suppressNotUsedWarning(); try (NDScope ignore = new NDScope()) { + uninvolved =
manager.create(new int[] {1}); + uninvolved.close(); inside = manager.create(new int[] {1}); // not tracked by any NDScope, but still managed by NDManager NDScope.unregister(inside); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 75dcffe7bbb..ad5f9c8d7c4 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -1582,6 +1582,10 @@ public String toString() { @Override public boolean equals(Object obj) { if (obj instanceof NDArray) { + // do no compare content if obj is released + if (((NDArray) obj).isReleased()) { + return this == obj; + } return contentEquals((NDArray) obj); } return false; From c0b94ac783d54f92e7c173fb81186242e43a0556 Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 26 Jan 2023 13:19:53 +0100 Subject: [PATCH 06/10] problem on unregister: if two NDArrays on different device equals throws an Exception - fixed --- .../src/main/java/ai/djl/pytorch/engine/PtNDArray.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index ad5f9c8d7c4..9ecbcc84729 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -1586,6 +1586,9 @@ public boolean equals(Object obj) { if (((NDArray) obj).isReleased()) { return this == obj; } + if (((NDArray) obj).getManager() != manager) { + return false; + } return contentEquals((NDArray) obj); } return false; From af04eab0a970003d9dd5859a49ef78b0eab72932 Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 26 Jan 2023 14:36:39 +0100 Subject: [PATCH 07/10] removed unregister 
from detach --- api/src/test/java/ai/djl/ndarray/NDScopeTest.java | 3 ++- .../src/main/java/ai/djl/mxnet/engine/MxNDArray.java | 1 - .../src/main/java/ai/djl/pytorch/engine/PtNDArray.java | 1 - .../src/main/java/ai/djl/tensorflow/engine/TfNDArray.java | 1 - 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/api/src/test/java/ai/djl/ndarray/NDScopeTest.java b/api/src/test/java/ai/djl/ndarray/NDScopeTest.java index ab8268ca6f0..eef151fec11 100644 --- a/api/src/test/java/ai/djl/ndarray/NDScopeTest.java +++ b/api/src/test/java/ai/djl/ndarray/NDScopeTest.java @@ -35,7 +35,8 @@ public void testNDScope() { } detached = manager.create(new int[] {1}); - detached.detach(); // detached from NDManager and NDScope + detached.detach(); // detached from NDManager + NDScope.unregister(detached); // and unregistered from NDScope } Assert.assertFalse(inside.isReleased()); diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index e71537cd87a..e62c593fc95 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -189,7 +189,6 @@ public void tempAttach(NDManager manager) { public void detach() { manager.detachInternal(getUid()); manager = MxNDManager.getSystemManager(); - NDScope.unregister(this); } private NDArray duplicate( diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 9ecbcc84729..678f2fe5848 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -367,7 +367,6 @@ public void tempAttach(NDManager manager) { public void detach() { manager.detachInternal(getUid()); manager = 
PtNDManager.getSystemManager(); - NDScope.unregister(this); } /** {@inheritDoc} */ diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index efd0dde322b..aea57521137 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -265,7 +265,6 @@ public void tempAttach(NDManager manager) { public void detach() { manager.detachInternal(getUid()); manager = TfNDManager.getSystemManager(); - NDScope.unregister(this); } /** {@inheritDoc} */ From d0530d09957aa34b5417e1db5aa56ab9eedb50b6 Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 26 Jan 2023 18:16:55 +0100 Subject: [PATCH 08/10] equals solved in NDScope internally - test on instance --- api/src/main/java/ai/djl/ndarray/NDScope.java | 31 ++++++++++++++++--- .../java/ai/djl/pytorch/engine/PtNDArray.java | 7 ----- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDScope.java b/api/src/main/java/ai/djl/ndarray/NDScope.java index 7adda1ff436..9380ff295f5 100644 --- a/api/src/main/java/ai/djl/ndarray/NDScope.java +++ b/api/src/main/java/ai/djl/ndarray/NDScope.java @@ -28,7 +28,7 @@ public class NDScope implements AutoCloseable { private static final ThreadLocal> SCOPE_STACK = ThreadLocal.withInitial(ArrayDeque::new); - private List resources; + private List resources; /** Constructs a new {@code NDScope} instance. 
*/ public NDScope() { @@ -46,7 +46,7 @@ public static void register(NDArray array) { if (queue.isEmpty()) { return; } - queue.getLast().resources.add(array); + queue.getLast().resources.add(new NDArrayWrapper(array)); } /** @@ -59,14 +59,14 @@ public static void unregister(NDArray array) { if (queue.isEmpty()) { return; } - queue.getLast().resources.remove(array); + queue.getLast().resources.remove(new NDArrayWrapper(array)); } /** {@inheritDoc} */ @Override public void close() { - for (NDArray array : resources) { - array.close(); + for (NDArrayWrapper arrayWrapper : resources) { + arrayWrapper.array.close(); } SCOPE_STACK.get().remove(this); } @@ -80,4 +80,25 @@ public void close() { public void suppressNotUsedWarning() { // do nothing } + + private static class NDArrayWrapper { + NDArray array; + + public NDArrayWrapper(NDArray array) { + this.array = array; + } + + @Override + public int hashCode() { + return 1; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof NDArrayWrapper)) { + return false; + } + return ((NDArrayWrapper) o).array == array; + } + } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 678f2fe5848..f77880da789 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -1581,13 +1581,6 @@ public String toString() { @Override public boolean equals(Object obj) { if (obj instanceof NDArray) { - // do no compare content if obj is released - if (((NDArray) obj).isReleased()) { - return this == obj; - } - if (((NDArray) obj).getManager() != manager) { - return false; - } return contentEquals((NDArray) obj); } return false; From 35f403bf16b146b2fb921f532650a30200d54e12 Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 26 Jan 2023 18:44:47 +0100 Subject: [PATCH 09/10] 
Update api/src/main/java/ai/djl/ndarray/NDScope.java Co-authored-by: Frank Liu --- api/src/main/java/ai/djl/ndarray/NDScope.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDScope.java b/api/src/main/java/ai/djl/ndarray/NDScope.java index 9380ff295f5..0bb56d2d5c4 100644 --- a/api/src/main/java/ai/djl/ndarray/NDScope.java +++ b/api/src/main/java/ai/djl/ndarray/NDScope.java @@ -28,7 +28,7 @@ public class NDScope implements AutoCloseable { private static final ThreadLocal> SCOPE_STACK = ThreadLocal.withInitial(ArrayDeque::new); - private List resources; + private IdentityHashMap resources; /** Constructs a new {@code NDScope} instance. */ public NDScope() { From dcf14141faed2f188ae10cc250ad26e36c1dbe5a Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 26 Jan 2023 18:54:38 +0100 Subject: [PATCH 10/10] using IdentityHashMap --- api/src/main/java/ai/djl/ndarray/NDScope.java | 34 ++++--------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDScope.java b/api/src/main/java/ai/djl/ndarray/NDScope.java index 0bb56d2d5c4..b705d08c0bf 100644 --- a/api/src/main/java/ai/djl/ndarray/NDScope.java +++ b/api/src/main/java/ai/djl/ndarray/NDScope.java @@ -13,9 +13,8 @@ package ai.djl.ndarray; import java.util.ArrayDeque; -import java.util.ArrayList; import java.util.Deque; -import java.util.List; +import java.util.IdentityHashMap; /** * A class that tracks {@link NDResource} objects created in the try-with-resource block and close @@ -32,7 +31,7 @@ public class NDScope implements AutoCloseable { /** Constructs a new {@code NDScope} instance. 
*/ public NDScope() { - resources = new ArrayList<>(); + resources = new IdentityHashMap<>(); SCOPE_STACK.get().addLast(this); } @@ -46,7 +45,7 @@ public static void register(NDArray array) { if (queue.isEmpty()) { return; } - queue.getLast().resources.add(new NDArrayWrapper(array)); + queue.getLast().resources.put(array, array); } /** @@ -59,14 +58,14 @@ public static void unregister(NDArray array) { if (queue.isEmpty()) { return; } - queue.getLast().resources.remove(new NDArrayWrapper(array)); + queue.getLast().resources.remove(array); } /** {@inheritDoc} */ @Override public void close() { - for (NDArrayWrapper arrayWrapper : resources) { - arrayWrapper.array.close(); + for (NDArray array : resources.keySet()) { + array.close(); } SCOPE_STACK.get().remove(this); } @@ -80,25 +79,4 @@ public void close() { public void suppressNotUsedWarning() { // do nothing } - - private static class NDArrayWrapper { - NDArray array; - - public NDArrayWrapper(NDArray array) { - this.array = array; - } - - @Override - public int hashCode() { - return 1; - } - - @Override - public boolean equals(Object o) { - if (!(o instanceof NDArrayWrapper)) { - return false; - } - return ((NDArrayWrapper) o).array == array; - } - } }