Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optionally close orphaned NDArrays using Java garbage collection #2273

Closed
wants to merge 37 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
1552f96
just the poc, compiles without testing
Dec 15, 2022
52cbaf6
moved PtNDArray interface out
Dec 15, 2022
5d67a61
creating NDArrays with or without proxy
Dec 15, 2022
396da57
build without test, poc uses switch successfully
Dec 15, 2022
f49c4b4
build without test and publishesToMavenLocal
Dec 15, 2022
8b06735
some proxy handling fixes
Dec 15, 2022
6fc1879
removed a logging
Dec 15, 2022
7204586
fixed double wrapping bug
Dec 16, 2022
8b52b01
catch exception silently if resource already closed
Dec 16, 2022
335586c
methods to remove NDManager without resources
Dec 16, 2022
320ff1d
renamed the switch to garbageCollectionOn
Dec 16, 2022
e0da7f3
add switch for gc in model
Dec 17, 2022
811bce8
common name for the switch
Dec 19, 2022
043592e
variable details on debugDump
Dec 19, 2022
e79c553
fixed a memory leak
Dec 21, 2022
5165335
global switch
Dec 27, 2022
bdc18a6
Merge remote-tracking branch 'origin/master' into gc-orphaned-resources
Dec 30, 2022
66a613a
sync fork
Dec 30, 2022
687e307
opened LayerNorm.Builder for inheritance
Dec 31, 2022
7062e5c
Merge remote-tracking branch 'origin/master' into gc-orphaned-resources
Dec 31, 2022
75bb5f2
uid-counter, getImplementation, debugCountNDArrays,getNumOfNDArraysIn…
Jan 4, 2023
d76ac55
Merge remote-tracking branch 'origin/master' into gc-orphaned-resources
Jan 4, 2023
64b056b
merged
Jan 4, 2023
3d3827d
Merge remote-tracking branch 'origin/master' into gc-orphaned-resources
Jan 7, 2023
866ecf1
merge fix
Jan 7, 2023
c574584
PtGradientCollector from master
Jan 7, 2023
c2889a1
Merge remote-tracking branch 'origin/master' into gc-orphaned-resources
Jan 8, 2023
ed138be
Merge remote-tracking branch 'origin/master' into gc-orphaned-resources
Jan 9, 2023
13241f9
Revert "A temporary solution to issue 2210 (#2304)"
Jan 9, 2023
664210c
fixed bug: ignore if array is already closed
Jan 9, 2023
be300cc
Revert "some inheritance opening needed in a particular project (#2231)"
Jan 9, 2023
b30c95a
removed two added methods from NDManager that are not necessary
Jan 9, 2023
a181614
introduced threadLocal reference queues
Jan 9, 2023
e30b771
Revert "Revert "some inheritance opening needed in a particular proje…
Jan 9, 2023
9b2f9bb
here I reverted to much
Jan 9, 2023
dc5ce14
Merge remote-tracking branch 'origin/master' into gc-orphaned-resources
Jan 9, 2023
6cf606e
added a method gc() to NDManager which explicitly calls checkQueue on…
Jan 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 95 additions & 26 deletions api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@
import java.nio.ShortBuffer;
import java.nio.charset.Charset;
import java.nio.file.Path;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** {@code BaseNDManager} is the default implementation of {@link NDManager}. */
public abstract class BaseNDManager implements NDManager {
Expand Down Expand Up @@ -304,6 +307,30 @@ public Device getDevice() {
return device;
}

/** {@inheritDoc} */
@Override
public List<NDArray> getManagedArrays() {
    // Arrays reachable through the regular resources: an NDResource contributes the arrays
    // it reports, a nested NDManager contributes its own managed arrays, and any other
    // kind of resource contributes nothing.
    Stream<NDArray> attached =
            resources.values().stream()
                    .flatMap(
                            res -> {
                                if (res instanceof NDResource) {
                                    return ((NDResource) res).getResourceNDArrays().stream();
                                }
                                if (res instanceof NDManager) {
                                    return ((NDManager) res).getManagedArrays().stream();
                                }
                                return Stream.empty();
                            });

    // Arrays reported by the temporarily attached resources.
    Stream<NDArray> temporary =
            tempResources.values().stream()
                    .flatMap(temp -> temp.resource.getResourceNDArrays().stream());

    return Stream.concat(attached, temporary).collect(Collectors.toList());
}

/** {@inheritDoc} */
@Override
public String toString() {
Expand All @@ -321,9 +348,6 @@ public String toString() {
/** {@inheritDoc} */
@Override
public synchronized void attachInternal(String resourceId, AutoCloseable resource) {
if (this instanceof SystemNDManager) {
return;
}
if (capped.get()) {
throw new IllegalStateException("NDManager is capped for addition of resources.");
}
Expand All @@ -333,9 +357,6 @@ public synchronized void attachInternal(String resourceId, AutoCloseable resourc
/** {@inheritDoc} */
@Override
public synchronized void attachUncappedInternal(String resourceId, AutoCloseable resource) {
if (this instanceof SystemNDManager) {
return;
}
if (closed.get()) {
throw new IllegalStateException("NDManager has been closed already.");
}
Expand All @@ -362,7 +383,8 @@ public synchronized void attachUncappedInternal(String resourceId, AutoCloseable
public void tempAttachInternal(
NDManager originalManager, String resourceId, NDResource resource) {
if (this instanceof SystemNDManager) {
return;
throw new IllegalStateException(
"System manager cannot be temp attached because it can't be closed..");
}
if (closed.get()) {
throw new IllegalStateException("NDManager has been closed already.");
Expand All @@ -373,9 +395,6 @@ public void tempAttachInternal(
/** {@inheritDoc} */
@Override
public synchronized void detachInternal(String resourceId) {
if (this instanceof SystemNDManager) {
return;
}
if (closed.get()) {
// This may happen in the middle of BaseNDManager.close()
return;
Expand All @@ -402,26 +421,13 @@ public NDList invoke(String operation, NDList src, PairList<String, ?> params) {
throw new UnsupportedOperationException("Not supported!");
}

/** {@inheritDoc} */
@Override
public void zeroGradients() {
    for (AutoCloseable resource : resources.values()) {
        if (resource instanceof NDManager) {
            // Nested managers handle the gradients of their own resources.
            ((NDManager) resource).zeroGradients();
            continue;
        }
        if (!(resource instanceof NDArray)) {
            // Only managers and arrays are relevant here.
            continue;
        }
        NDArray ndArray = (NDArray) resource;
        if (ndArray.hasGradient()) {
            // The gradient is cleared by subtracting it from itself via subi.
            ndArray.getGradient().subi(ndArray.getGradient());
        }
    }
}

/** {@inheritDoc} */
@Override
public void close() {
if (this instanceof SystemNDManager) {
return;
throw new IllegalStateException(
"The SystemNDManager can not be closed. It is global and lives for the duration"
+ " of the process");
}
if (!closed.getAndSet(true)) {
for (AutoCloseable closeable : resources.values()) {
Expand Down Expand Up @@ -463,6 +469,69 @@ public void debugDump(int level) {
}
}

/**
 * Prints information about this {@link NDManager} and all sub-managers to the console.
 *
 * <p>A header line is printed for this manager (the tail of its uid, its device and its
 * resource count). Each attached resource then gets its own line, one level deeper: nested
 * {@link BaseNDManager}s are dumped recursively, {@link NDArray}s are printed with their uid
 * and shape, and any other {@link NDResource} is printed as a generic entry. Resources of
 * other kinds are not printed.
 *
 * @param level the level of this {@link NDManager} in the hierarchy
 */
public void debugDumpDetailed(int level) {
    StringBuilder sb = debugDumpIndent(level);
    sb.append("\\--- NDManager(")
            .append(uid.substring(24))
            .append(", ")
            .append(device)
            .append(") resource count: ")
            .append(resources.size());

    System.out.println(sb); // NOPMD
    for (AutoCloseable c : resources.values()) {
        if (c instanceof BaseNDManager) {
            ((BaseNDManager) c).debugDumpDetailed(level + 1);
        } else if (c instanceof NDManager) {
            // Only BaseNDManager offers debugDumpDetailed; casting any other NDManager
            // implementation would fail with a ClassCastException, so just note it.
            StringBuilder sb2 = debugDumpIndent(level + 1);
            sb2.append("\\--- NDManager (not a BaseNDManager, details unavailable)");
            System.out.println(sb2); // NOPMD
        } else if (c instanceof NDArray) {
            NDArray array = (NDArray) c;
            StringBuilder sb2 = debugDumpIndent(level + 1);
            sb2.append("\\--- NDArray(")
                    .append(array.getUid())
                    .append(", Shape")
                    .append(array.getShape())
                    .append(')');
            System.out.println(sb2); // NOPMD
        } else if (c instanceof NDResource) {
            StringBuilder sb2 = debugDumpIndent(level + 1);
            sb2.append("\\--- other NDResource");
            System.out.println(sb2); // NOPMD
        }
    }
}

/** Creates a builder pre-filled with one indentation unit per hierarchy level. */
private static StringBuilder debugDumpIndent(int level) {
    StringBuilder indent = new StringBuilder(100);
    for (int i = 0; i < level; ++i) {
        indent.append(" ");
    }
    return indent;
}

/**
 * Counts the {@link NDArray}s in the hierarchy of this {@link NDManager}.
 *
 * <p>Every directly attached {@link NDArray} counts as one, an attached {@link NDList} counts
 * with its size, and nested {@link BaseNDManager}s are counted recursively. All other
 * resources are ignored.
 *
 * @return the number of {@link NDArray} in the hierarchy of this {@link NDManager}
 */
public int debugCountNDArrays() {
    return resources.values().stream()
            .mapToInt(
                    res -> {
                        if (res instanceof BaseNDManager) {
                            return ((BaseNDManager) res).debugCountNDArrays();
                        }
                        if (res instanceof NDArray) {
                            return 1;
                        }
                        if (res instanceof NDList) {
                            return ((NDList) res).size();
                        }
                        return 0;
                    })
            .sum();
}

/**
* Returns the value of the {@code alternativeManager} field of this manager.
*
* <p>NOTE(review): whether this can be {@code null} is not visible here; confirm before
* dereferencing the result.
*
* @return the {@code alternativeManager} of this manager
*/
NDManager getAlternativeManager() {
return alternativeManager;
}
Expand Down
8 changes: 8 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
Expand Down Expand Up @@ -4685,6 +4687,12 @@ default NDArray countNonzero(int axis) {
*/
NDArray erfinv();

/** {@inheritDoc} */
@Override
default List<NDArray> getResourceNDArrays() {
// By default an array reports itself as its only contained array; singletonList
// yields an immutable one-element list.
return Collections.singletonList(this);
}

/**
* Returns an internal representative of Native {@code NDArray}.
*
Expand Down
7 changes: 7 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDList.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
Expand Down Expand Up @@ -269,6 +270,12 @@ public NDManager getManager() {
return head().getManager();
}

/** {@inheritDoc} */
@Override
public List<NDArray> getResourceNDArrays() {
// The list itself is returned rather than a copy, so any change made through the
// returned reference affects this NDList.
return this;
}

/** {@inheritDoc} */
@Override
public void attach(NDManager manager) {
Expand Down
28 changes: 25 additions & 3 deletions api/src/main/java/ai/djl/ndarray/NDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.gc.NDArrayProxyMaker;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Translator;
Expand All @@ -34,6 +35,7 @@
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.List;

/**
* NDArray managers are used to create <I>NDArrays</I> (n-dimensional array on native engine).
Expand Down Expand Up @@ -740,6 +742,22 @@ default NDList load(Path path, Device device) {
return newSubManager(device).load(path);
}

/**
* Returns the {@link NDArrayProxyMaker}.
*
* <p>This default implementation does not provide one and always throws; implementations
* that support it are expected to override this method.
*
* @return the {@link NDArrayProxyMaker}
* @throws UnsupportedOperationException if this method is not overridden by the implementation
*/
default NDArrayProxyMaker getProxyMaker() {
throw new UnsupportedOperationException("Not supported");
}

/**
* Checks the referenceQueue for NDArrays that are garbage collected by Java GC and closes them.
*
* <p>This default implementation performs no check and always throws; implementations that
* support it are expected to override this method.
*
* @throws UnsupportedOperationException if this method is not overridden by the implementation
*/
default void gc() {
throw new UnsupportedOperationException("Not supported");
}

/**
* Sets the name for the NDManager.
*
Expand Down Expand Up @@ -1534,6 +1552,13 @@ default NDArray hanningWindow(long numPoints) {
*/
Device getDevice();

/**
* Returns all {@link NDArray}s managed by this manager (including recursively).
*
* @return all {@link NDArray}s managed by this manager (including recursively)
*/
List<NDArray> getManagedArrays();

/**
* Attaches a resource to this {@code NDManager}.
*
Expand Down Expand Up @@ -1668,9 +1693,6 @@ default void tempAttachAll(NDResource... resources) {
*/
Engine getEngine();

/** Sets all the gradients within the NDManager to zero. */
void zeroGradients();

/** {@inheritDoc} */
@Override
void close();
Expand Down
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
*/
package ai.djl.ndarray;

import java.util.List;

/** An object which is managed by an {@link NDManager} and tracks the manager it is attached to. */
public interface NDResource extends AutoCloseable {

Expand All @@ -22,6 +24,13 @@ public interface NDResource extends AutoCloseable {
*/
NDManager getManager();

/**
* Returns the {@link NDArray} or {@link NDArray}s contained within this resource.
*
* @return the {@link NDArray} or {@link NDArray}s contained within this resource
*/
List<NDArray> getResourceNDArrays();

/**
* Attaches this {@link NDResource} to the specified {@link NDManager}.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray.gc;

import ai.djl.ndarray.NDArray;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

/**
 * {@code DynamicInvocationHandler} implements the {@link InvocationHandler}.
 *
 * <p>The handler does not hold its target {@link NDArray} directly. On every call it looks the
 * array up in {@code map} under {@code uid} and forwards the call to it reflectively.
 */
public class DynamicInvocationHandler implements InvocationHandler {

    private static final Logger logger = LoggerFactory.getLogger(DynamicInvocationHandler.class);

    WeakHashMapWrapper<String, NDArray> map;
    String uid;

    NDArrayProxyMaker ndArrayProxyMaker;

    /**
     * Creates a new instance of {@code DynamicInvocationHandler}.
     *
     * @param uid the uid
     * @param map the map
     * @param ndArrayProxyMaker the ndArrayProxyMaker
     */
    public DynamicInvocationHandler(
            String uid,
            WeakHashMapWrapper<String, NDArray> map,
            NDArrayProxyMaker ndArrayProxyMaker) {
        this.map = map;
        this.uid = uid;
        this.ndArrayProxyMaker = ndArrayProxyMaker;
    }

    /**
     * Forwards a call made on the proxy to the {@link NDArray} registered under this handler's
     * uid.
     *
     * <p>Two method names are answered by the handler itself: {@code getNumOfNDArraysInGCMap}
     * returns the size of the map and {@code getImplementation} returns the mapped array (which
     * may be {@code null}). Every other call is delegated reflectively.
     *
     * <p>Unchecked exceptions and errors thrown by the target method are rethrown unchanged, so
     * that callers observe the same exception types as when calling the array directly. Checked
     * exceptions and reflective access failures are wrapped in a {@link GCRuntimeException}.
     *
     * @param proxy the proxy instance the method was invoked on
     * @param method the method invoked on the proxy
     * @param args the arguments of the call, or {@code null} if the method takes none
     * @return the value returned by the handled or delegated method
     * @throws GCRuntimeException if no array is mapped to the uid, or if the reflective call
     *     fails with a checked exception
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) {
        String methodName = method.getName();
        if ("getNumOfNDArraysInGCMap".equals(methodName)) {
            return this.map.size();
        }
        if ("getImplementation".equals(methodName)) {
            return map.get(uid);
        }

        NDArray ndArray = map.get(uid);
        if (ndArray == null) {
            logger.error("no nDArray found for uid: {}", uid);
            throw new GCRuntimeException(
                    "no nDArray could be found for uid: "
                            + uid
                            + ". Consider calling the methods of a particular nDArray only from"
                            + " one thread or do not switch on garbage collection.");
        }

        try {
            return method.invoke(ndArray, args);
        } catch (InvocationTargetException e) {
            // Reflection wraps whatever the target threw. Unwrap unchecked failures so the
            // proxy stays transparent; the caller decides how to handle or log them.
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            logger.error("Error invoking method", e);
            throw new GCRuntimeException(e);
        } catch (IllegalAccessException e) {
            logger.error("Error invoking method", e);
            throw new GCRuntimeException(e);
        }
    }
}
Loading