memory leak in Pt indexing #2300

Merged · 5 commits · Jan 6, 2023
Changes from all commits
@@ -364,119 +364,134 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m
        }
        List<NDIndexElement> indices = index.getIndices();
        long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size());
-        ListIterator<NDIndexElement> it = indices.listIterator();
-        while (it.hasNext()) {
-            if (it.nextIndex() == index.getEllipsisIndex()) {
-                PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
-            }
-
-            NDIndexElement elem = it.next();
-            if (elem instanceof NDIndexNull) {
-                PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false);
-            } else if (elem instanceof NDIndexSlice) {
-                Long min = ((NDIndexSlice) elem).getMin();
-                Long max = ((NDIndexSlice) elem).getMax();
-                Long step = ((NDIndexSlice) elem).getStep();
-                int nullSliceBin = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0);
-                // nullSliceBin encodes whether the slice (min, max) is null:
-                // is_null == 1, ! is_null == 0;
-                // 0b11 == 3, 0b10 = 2, ...
-                PyTorchLibrary.LIB.torchIndexAppendSlice(
-                        torchIndexHandle,
-                        min == null ? 0 : min,
-                        max == null ? 0 : max,
-                        step == null ? 1 : step,
-                        nullSliceBin);
-            } else if (elem instanceof NDIndexAll) {
-                PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, 0, 0, 1, 3);
-            } else if (elem instanceof NDIndexFixed) {
-                PyTorchLibrary.LIB.torchIndexAppendFixed(
-                        torchIndexHandle, ((NDIndexFixed) elem).getIndex());
-            } else if (elem instanceof NDIndexBooleans) {
-                PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex();
-                PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle());
-            } else if (elem instanceof NDIndexTake) {
-                PtNDArray indexArr = manager.from(((NDIndexTake) elem).getIndex());
-                if (indexArr.getDataType() != DataType.INT64) {
-                    indexArr = indexArr.toType(DataType.INT64, true);
-                }
-                PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle());
-            } else if (elem instanceof NDIndexPick) {
-                // Backward compatible
-                NDIndexFullPick fullPick =
-                        NDIndexFullPick.fromIndex(index, ndArray.getShape()).get();
-                return pick(ndArray, manager.from(fullPick.getIndices()), fullPick.getAxis());
-            }
-        }
-        if (indices.size() == index.getEllipsisIndex()) {
-            PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
-        }
-
-        return new PtNDArray(
-                manager,
-                PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle));
+        try {
+            // Index aggregation
+            ListIterator<NDIndexElement> it = indices.listIterator();
+            while (it.hasNext()) {
+                if (it.nextIndex() == index.getEllipsisIndex()) {
+                    PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
+                }
+
+                NDIndexElement elem = it.next();
+                if (elem instanceof NDIndexNull) {
+                    PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false);
+                } else if (elem instanceof NDIndexSlice) {
+                    Long min = ((NDIndexSlice) elem).getMin();
+                    Long max = ((NDIndexSlice) elem).getMax();
+                    Long step = ((NDIndexSlice) elem).getStep();
+                    int nullSliceBinary = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0);
+                    // nullSliceBinary encodes whether the slice end {min, max} is null:
+                    // is_null == 1, ! is_null == 0;
+                    // 0b11 == 3, 0b10 = 2, ...
+                    // If {min, max} is null, then its value is ineffective, thus set to -1.
+                    PyTorchLibrary.LIB.torchIndexAppendSlice(
+                            torchIndexHandle,
+                            min == null ? -1 : min,
+                            max == null ? -1 : max,
+                            step == null ? 1 : step,
+                            nullSliceBinary);
+                } else if (elem instanceof NDIndexAll) {
+                    PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, -1, -1, 1, 3);
+                } else if (elem instanceof NDIndexFixed) {
+                    PyTorchLibrary.LIB.torchIndexAppendFixed(
+                            torchIndexHandle, ((NDIndexFixed) elem).getIndex());
+                } else if (elem instanceof NDIndexBooleans) {
+                    PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex();
+                    PyTorchLibrary.LIB.torchIndexAppendArray(
+                            torchIndexHandle, indexArr.getHandle());
+                } else if (elem instanceof NDIndexTake) {
+                    PtNDArray indexArr = manager.from(((NDIndexTake) elem).getIndex());
+                    if (indexArr.getDataType() != DataType.INT64) {
+                        indexArr = indexArr.toType(DataType.INT64, true);
+                    }
+                    PyTorchLibrary.LIB.torchIndexAppendArray(
+                            torchIndexHandle, indexArr.getHandle());
+                } else if (elem instanceof NDIndexPick) {
+                    // Backward compatible
+                    NDIndexFullPick fullPick =
+                            NDIndexFullPick.fromIndex(index, ndArray.getShape()).get();
+                    return pick(ndArray, manager.from(fullPick.getIndices()), fullPick.getAxis());
+                }
+            }
+            if (indices.size() == index.getEllipsisIndex()) {
+                PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
+            }
+            long ret = PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle);
+            return new PtNDArray(manager, ret);
+        } finally {
+            PyTorchLibrary.LIB.torchDeleteIndex(torchIndexHandle);
+        }
    }

    @SuppressWarnings("OptionalGetWithoutIsPresent")
    public static void indexAdvPut(PtNDArray ndArray, NDIndex index, PtNDArray data) {
        if (ndArray == null) {
            return;
        }

        // Index aggregation
        List<NDIndexElement> indices = index.getIndices();
        long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size());
-        ListIterator<NDIndexElement> it = indices.listIterator();
-        while (it.hasNext()) {
-            if (it.nextIndex() == index.getEllipsisIndex()) {
-                PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
-            }
-
-            NDIndexElement elem = it.next();
-            if (elem instanceof NDIndexNull) {
-                PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false);
-            } else if (elem instanceof NDIndexSlice) {
-                Long min = ((NDIndexSlice) elem).getMin();
-                Long max = ((NDIndexSlice) elem).getMax();
-                Long step = ((NDIndexSlice) elem).getStep();
-                int nullSliceBin = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0);
-                // nullSliceBin encodes whether the slice (min, max) is null:
-                // is_null == 1, ! is_null == 0;
-                // 0b11 == 3, 0b10 = 2, ...
-                PyTorchLibrary.LIB.torchIndexAppendSlice(
-                        torchIndexHandle,
-                        min == null ? 0 : min,
-                        max == null ? 0 : max,
-                        step == null ? 1 : step,
-                        nullSliceBin);
-            } else if (elem instanceof NDIndexAll) {
-                PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, 0, 0, 1, 3);
-            } else if (elem instanceof NDIndexFixed) {
-                PyTorchLibrary.LIB.torchIndexAppendFixed(
-                        torchIndexHandle, ((NDIndexFixed) elem).getIndex());
-            } else if (elem instanceof NDIndexBooleans) {
-                PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex();
-                PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle());
-            } else if (elem instanceof NDIndexTake) {
-                PtNDArray indexArr = (PtNDArray) ((NDIndexTake) elem).getIndex();
-                if (indexArr.getDataType() != DataType.INT64) {
-                    indexArr = indexArr.toType(DataType.INT64, true);
-                }
-                PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle());
-            } else if (elem instanceof NDIndexPick) {
-                // Backward compatible
-                NDIndexFullPick fullPick =
-                        NDIndexFullPick.fromIndex(index, ndArray.getShape()).get();
-                pick(ndArray, ndArray.getManager().from(fullPick.getIndices()), fullPick.getAxis());
-                return;
-            }
-        }
-        if (indices.size() == index.getEllipsisIndex()) {
-            PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
-        }
-
-        PyTorchLibrary.LIB.torchIndexAdvPut(
-                ndArray.getHandle(), torchIndexHandle, data.getHandle());
+        try {
+            // Index aggregation
+            ListIterator<NDIndexElement> it = indices.listIterator();
+            while (it.hasNext()) {
+                if (it.nextIndex() == index.getEllipsisIndex()) {
+                    PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
+                }
+
+                NDIndexElement elem = it.next();
+                if (elem instanceof NDIndexNull) {
+                    PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false);
+                } else if (elem instanceof NDIndexSlice) {
+                    Long min = ((NDIndexSlice) elem).getMin();
+                    Long max = ((NDIndexSlice) elem).getMax();
+                    Long step = ((NDIndexSlice) elem).getStep();
+                    int nullSliceBinary = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0);
+                    // nullSliceBinary encodes whether the slice end {min, max} is null:
+                    // is_null == 1, ! is_null == 0;
+                    // 0b11 == 3, 0b10 = 2, ...
+                    // If {min, max} is null, then its value is ineffective, thus set to -1.
+                    PyTorchLibrary.LIB.torchIndexAppendSlice(
+                            torchIndexHandle,
+                            min == null ? -1 : min,
+                            max == null ? -1 : max,
+                            step == null ? 1 : step,
+                            nullSliceBinary);
+                } else if (elem instanceof NDIndexAll) {
+                    PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, -1, -1, 1, 3);
+                } else if (elem instanceof NDIndexFixed) {
+                    PyTorchLibrary.LIB.torchIndexAppendFixed(
+                            torchIndexHandle, ((NDIndexFixed) elem).getIndex());
+                } else if (elem instanceof NDIndexBooleans) {
+                    PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex();
+                    PyTorchLibrary.LIB.torchIndexAppendArray(
+                            torchIndexHandle, indexArr.getHandle());
+                } else if (elem instanceof NDIndexTake) {
+                    PtNDArray indexArr = (PtNDArray) ((NDIndexTake) elem).getIndex();
+                    if (indexArr.getDataType() != DataType.INT64) {
+                        indexArr = indexArr.toType(DataType.INT64, true);
+                    }
+                    PyTorchLibrary.LIB.torchIndexAppendArray(
+                            torchIndexHandle, indexArr.getHandle());
+                } else if (elem instanceof NDIndexPick) {
+                    // Backward compatible
+                    NDIndexFullPick fullPick =
+                            NDIndexFullPick.fromIndex(index, ndArray.getShape()).get();
+                    pick(
+                            ndArray,
+                            ndArray.getManager().from(fullPick.getIndices()),
+                            fullPick.getAxis());
+                    return;
+                }
+            }
+            if (indices.size() == index.getEllipsisIndex()) {
+                PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
+            }
+
+            PyTorchLibrary.LIB.torchIndexAdvPut(
+                    ndArray.getHandle(), torchIndexHandle, data.getHandle());
+        } finally {
+            PyTorchLibrary.LIB.torchDeleteIndex(torchIndexHandle);
+        }
    }

    public static void indexSet(
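A note on the two-bit null-slice encoding kept by this change: it only records which slice bounds were omitted, so the native side knows to ignore the new -1 placeholder values. The snippet below is illustrative only (the helper class and the sample slices are not part of the PR); it reproduces the same arithmetic used in indexAdv and indexAdvPut:

```java
public final class NullSliceEncodingDemo {

    // Same expression as in indexAdv/indexAdvPut: bit 1 <- "min omitted", bit 0 <- "max omitted".
    static int nullSliceBinary(Long min, Long max) {
        return (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0);
    }

    public static void main(String[] args) {
        System.out.println(nullSliceBinary(2L, 5L));     // x[2:5] -> 0b00 == 0, bounds passed as 2 and 5
        System.out.println(nullSliceBinary(2L, null));   // x[2:]  -> 0b01 == 1, max passed as the -1 sentinel
        System.out.println(nullSliceBinary(null, 5L));   // x[:5]  -> 0b10 == 2, min passed as the -1 sentinel
        System.out.println(nullSliceBinary(null, null)); // x[:]   -> 0b11 == 3, the value NDIndexAll hard-codes
    }
}
```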
@@ -214,6 +214,8 @@ native void torchIndexPut(

    native void torchDeleteTensor(long handle);

+    native void torchDeleteIndex(long handle);
+
    native void torchDeleteModule(long handle);

    native void torchDeleteIValue(long handle);
@@ -311,6 +311,14 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDeleteTensor(
  API_END()
}

+JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDeleteIndex(
+    JNIEnv* env, jobject jthis, jlong jtorch_index_handle) {
+  API_BEGIN()
+  auto* index_ptr = reinterpret_cast<std::vector<torch::indexing::TensorIndex>*>(jtorch_index_handle);
+  delete index_ptr;
+  API_END()
+}
+
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchToSparse(
    JNIEnv* env, jobject jthis, jlong jhandle) {
  API_BEGIN()
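Taken together, the three files pair the handle returned by torchIndexInit (presumably the std::vector<torch::indexing::TensorIndex> that the new JNI function casts and deletes) with a torchDeleteIndex call in a finally block, so the native index list is freed even when indexing returns early or throws. A hedged way to exercise this path from DJL's public API is a tight indexing loop; before the fix, native memory grew with each iteration. The shapes, index expression, and loop count below are illustrative, not taken from the PR:

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;

public final class IndexingLoopSketch {

    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray array = manager.arange(0, 1_000_000).reshape(1000, 1000);
            for (int i = 0; i < 10_000; ++i) {
                // On the PyTorch engine this get() should go through the indexAdv path above;
                // with this PR the temporary native index handle is released in the finally block.
                try (NDArray row = array.get(new NDIndex("{}, :", i % 1000))) {
                    // consume the slice; closing it frees the result tensor itself
                }
            }
        }
    }
}
```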