Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements NDScope to automatically close NDArray in the scope #2321

Merged
merged 10 commits into from
Jan 26, 2023
31 changes: 26 additions & 5 deletions api/src/main/java/ai/djl/ndarray/NDScope.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class NDScope implements AutoCloseable {
private static final ThreadLocal<Deque<NDScope>> SCOPE_STACK =
ThreadLocal.withInitial(ArrayDeque::new);

private List<NDArray> resources;
private List<NDArrayWrapper> resources;
enpasos marked this conversation as resolved.
Show resolved Hide resolved

/** Constructs a new {@code NDScope} instance. */
public NDScope() {
Expand All @@ -46,7 +46,7 @@ public static void register(NDArray array) {
if (queue.isEmpty()) {
return;
}
queue.getLast().resources.add(array);
queue.getLast().resources.add(new NDArrayWrapper(array));
}

/**
Expand All @@ -59,14 +59,14 @@ public static void unregister(NDArray array) {
if (queue.isEmpty()) {
return;
}
queue.getLast().resources.remove(array);
queue.getLast().resources.remove(new NDArrayWrapper(array));
}

/** {@inheritDoc} */
@Override
public void close() {
for (NDArray array : resources) {
array.close();
for (NDArrayWrapper arrayWrapper : resources) {
arrayWrapper.array.close();
}
SCOPE_STACK.get().remove(this);
}
Expand All @@ -80,4 +80,25 @@ public void close() {
public void suppressNotUsedWarning() {
// do nothing
}

private static class NDArrayWrapper {
NDArray array;

public NDArrayWrapper(NDArray array) {
this.array = array;
}

@Override
public int hashCode() {
return 1;
}

@Override
public boolean equals(Object o) {
if (!(o instanceof NDArrayWrapper)) {
return false;
}
return ((NDArrayWrapper) o).array == array;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1581,13 +1581,6 @@ public String toString() {
@Override
public boolean equals(Object obj) {
if (obj instanceof NDArray) {
// do no compare content if obj is released
if (((NDArray) obj).isReleased()) {
return this == obj;
}
if (((NDArray) obj).getManager() != manager) {
return false;
}
return contentEquals((NDArray) obj);
}
return false;
Expand Down