From d3ab825caf605b9436c7a870d8617be00ec46309 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Fri, 6 Jan 2023 09:19:44 -0800 Subject: [PATCH 1/5] memory leak in Pt indexing --- .../src/main/java/ai/djl/pytorch/jni/JniUtils.java | 10 ++++++---- .../native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc | 1 + 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index f0b2d183162..3d3854c6b9a 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -363,6 +363,7 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m return ndArray; } List indices = index.getIndices(); + // Native resources allocated here will be closed inside torchIndexAdvGet long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size()); ListIterator it = indices.listIterator(); while (it.hasNext()) { @@ -378,17 +379,18 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m Long max = ((NDIndexSlice) elem).getMax(); Long step = ((NDIndexSlice) elem).getStep(); int nullSliceBin = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); - // nullSliceBin encodes whether the slice (min, max) is null: + // nullSliceBin encodes whether the slice ends {min, max} is null: // is_null == 1, ! is_null == 0; // 0b11 == 3, 0b10 = 2, ... + // If {min, max} is null, then its value is ineffective, thus set to -1. PyTorchLibrary.LIB.torchIndexAppendSlice( torchIndexHandle, - min == null ? 0 : min, - max == null ? 0 : max, + min == null ? -1 : min, + max == null ? -1 : max, step == null ? 1 : step, nullSliceBin); } else if (elem instanceof NDIndexAll) { - PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, 0, 0, 1, 3); + PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, -1, -1, 1, 3); } else if (elem instanceof NDIndexFixed) { PyTorchLibrary.LIB.torchIndexAppendFixed( torchIndexHandle, ((NDIndexFixed) elem).getIndex()); diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc index f81a66e448d..650c74ae6a0 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc @@ -135,6 +135,7 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAdvGet( const auto* tensor_ptr = reinterpret_cast(jhandle); auto* index_ptr = reinterpret_cast*>(jtorch_index_handle); torch::Tensor* ret_ptr = new torch::Tensor(tensor_ptr->index(*index_ptr)); + delete index_ptr; return reinterpret_cast(ret_ptr); API_END_RETURN() } From 8e92e56eb2ba1f283357a74dd2553b9e7939ec9b Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Fri, 6 Jan 2023 09:23:31 -0800 Subject: [PATCH 2/5] doc --- .../src/main/java/ai/djl/pytorch/jni/JniUtils.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 3d3854c6b9a..b0b13d96d7c 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -379,7 +379,7 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m Long max = ((NDIndexSlice) elem).getMax(); Long step = ((NDIndexSlice) elem).getStep(); int nullSliceBin = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); - // nullSliceBin encodes whether the slice ends {min, max} is null: + // nullSliceBin encodes whether the slice end {min, max} is null: // is_null == 1, ! is_null == 0; // 0b11 == 3, 0b10 = 2, ... // If {min, max} is null, then its value is ineffective, thus set to -1. From 5288facd56736c417e03e9eb825dae34980be9f6 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Fri, 6 Jan 2023 09:59:47 -0800 Subject: [PATCH 3/5] move native resource deletion to java --- .../java/ai/djl/pytorch/jni/JniUtils.java | 32 +++++++++++-------- .../ai/djl/pytorch/jni/PyTorchLibrary.java | 2 ++ ...i_djl_pytorch_jni_PyTorchLibrary_tensor.cc | 9 +++++- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index b0b13d96d7c..60c102288cb 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -363,7 +363,6 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m return ndArray; } List indices = index.getIndices(); - // Native resources allocated here will be closed inside torchIndexAdvGet long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size()); ListIterator it = indices.listIterator(); while (it.hasNext()) { @@ -378,8 +377,8 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m Long min = ((NDIndexSlice) elem).getMin(); Long max = ((NDIndexSlice) elem).getMax(); Long step = ((NDIndexSlice) elem).getStep(); - int nullSliceBin = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); - // nullSliceBin encodes whether the slice end {min, max} is null: + int nullSliceBinary = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); + // nullSliceBinary encodes whether the slice end {min, max} is null: // is_null == 1, ! is_null == 0; // 0b11 == 3, 0b10 = 2, ... // If {min, max} is null, then its value is ineffective, thus set to -1. @@ -388,7 +387,7 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m min == null ? -1 : min, max == null ? -1 : max, step == null ? 1 : step, - nullSliceBin); + nullSliceBinary); } else if (elem instanceof NDIndexAll) { PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, -1, -1, 1, 3); } else if (elem instanceof NDIndexFixed) { @@ -414,9 +413,12 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); } - return new PtNDArray( - manager, - PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle)); + PtNDArray ret = + new PtNDArray( + manager, + PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle)); + PyTorchLibrary.LIB.torchDeleteIndex(torchIndexHandle); + return ret; } @SuppressWarnings("OptionalGetWithoutIsPresent") @@ -424,7 +426,6 @@ public static void indexAdvPut(PtNDArray ndArray, NDIndex index, PtNDArray data) if (ndArray == null) { return; } - // Index aggregation List indices = index.getIndices(); long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size()); @@ -441,18 +442,19 @@ public static void indexAdvPut(PtNDArray ndArray, NDIndex index, PtNDArray data) Long min = ((NDIndexSlice) elem).getMin(); Long max = ((NDIndexSlice) elem).getMax(); Long step = ((NDIndexSlice) elem).getStep(); - int nullSliceBin = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); - // nullSliceBin encodes whether the slice (min, max) is null: + int nullSliceBinary = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); + // nullSliceBinary encodes whether the slice end {min, max} is null: // is_null == 1, ! is_null == 0; // 0b11 == 3, 0b10 = 2, ... + // If {min, max} is null, then its value is ineffective, thus set to -1. PyTorchLibrary.LIB.torchIndexAppendSlice( torchIndexHandle, - min == null ? 0 : min, - max == null ? 0 : max, + min == null ? -1 : min, + max == null ? -1 : max, step == null ? 1 : step, - nullSliceBin); + nullSliceBinary); } else if (elem instanceof NDIndexAll) { - PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, 0, 0, 1, 3); + PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, -1, -1, 1, 3); } else if (elem instanceof NDIndexFixed) { PyTorchLibrary.LIB.torchIndexAppendFixed( torchIndexHandle, ((NDIndexFixed) elem).getIndex()); @@ -479,6 +481,8 @@ public static void indexAdvPut(PtNDArray ndArray, NDIndex index, PtNDArray data) PyTorchLibrary.LIB.torchIndexAdvPut( ndArray.getHandle(), torchIndexHandle, data.getHandle()); + + PyTorchLibrary.LIB.torchDeleteIndex(torchIndexHandle); } public static void indexSet( diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index f6c4b3d2c20..c28910012cc 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -214,6 +214,8 @@ native void torchIndexPut( native void torchDeleteTensor(long handle); + native void torchDeleteIndex(long handle); + native void torchDeleteModule(long handle); native void torchDeleteIValue(long handle); diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc index 650c74ae6a0..06349f80dbf 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc @@ -135,7 +135,6 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAdvGet( const auto* tensor_ptr = reinterpret_cast(jhandle); auto* index_ptr = reinterpret_cast*>(jtorch_index_handle); torch::Tensor* ret_ptr = new torch::Tensor(tensor_ptr->index(*index_ptr)); - delete index_ptr; return reinterpret_cast(ret_ptr); API_END_RETURN() } @@ -312,6 +311,14 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDeleteTensor( API_END() } +JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDeleteIndex( + JNIEnv* env, jobject jthis, jlong jtorch_index_handle) { + API_BEGIN() + auto* index_ptr = reinterpret_cast*>(jtorch_index_handle); + delete index_ptr; + API_END() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchToSparse( JNIEnv* env, jobject jthis, jlong jhandle) { API_BEGIN() From 919ef93d7989aa73fb70e6530a768dfb83dbabaf Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Fri, 6 Jan 2023 10:30:15 -0800 Subject: [PATCH 4/5] try finally block --- .../java/ai/djl/pytorch/jni/JniUtils.java | 218 ++++++++++-------- 1 file changed, 116 insertions(+), 102 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 60c102288cb..431568da61a 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -362,63 +362,70 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m if (ndArray == null) { return ndArray; } + PtNDArray result; List indices = index.getIndices(); long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size()); - ListIterator it = indices.listIterator(); - while (it.hasNext()) { - if (it.nextIndex() == index.getEllipsisIndex()) { - PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); - } + try { + // Index aggregation + ListIterator it = indices.listIterator(); + while (it.hasNext()) { + if (it.nextIndex() == index.getEllipsisIndex()) { + PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); + } - NDIndexElement elem = it.next(); - if (elem instanceof NDIndexNull) { - PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false); - } else if (elem instanceof NDIndexSlice) { - Long min = ((NDIndexSlice) elem).getMin(); - Long max = ((NDIndexSlice) elem).getMax(); - Long step = ((NDIndexSlice) elem).getStep(); - int nullSliceBinary = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); - // nullSliceBinary encodes whether the slice end {min, max} is null: - // is_null == 1, ! is_null == 0; - // 0b11 == 3, 0b10 = 2, ... - // If {min, max} is null, then its value is ineffective, thus set to -1. - PyTorchLibrary.LIB.torchIndexAppendSlice( - torchIndexHandle, - min == null ? -1 : min, - max == null ? -1 : max, - step == null ? 1 : step, - nullSliceBinary); - } else if (elem instanceof NDIndexAll) { - PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, -1, -1, 1, 3); - } else if (elem instanceof NDIndexFixed) { - PyTorchLibrary.LIB.torchIndexAppendFixed( - torchIndexHandle, ((NDIndexFixed) elem).getIndex()); - } else if (elem instanceof NDIndexBooleans) { - PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex(); - PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle()); - } else if (elem instanceof NDIndexTake) { - PtNDArray indexArr = manager.from(((NDIndexTake) elem).getIndex()); - if (indexArr.getDataType() != DataType.INT64) { - indexArr = indexArr.toType(DataType.INT64, true); + NDIndexElement elem = it.next(); + if (elem instanceof NDIndexNull) { + PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false); + } else if (elem instanceof NDIndexSlice) { + Long min = ((NDIndexSlice) elem).getMin(); + Long max = ((NDIndexSlice) elem).getMax(); + Long step = ((NDIndexSlice) elem).getStep(); + int nullSliceBinary = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); + // nullSliceBinary encodes whether the slice end {min, max} is null: + // is_null == 1, ! is_null == 0; + // 0b11 == 3, 0b10 = 2, ... + // If {min, max} is null, then its value is ineffective, thus set to -1. + PyTorchLibrary.LIB.torchIndexAppendSlice( + torchIndexHandle, + min == null ? -1 : min, + max == null ? -1 : max, + step == null ? 1 : step, + nullSliceBinary); + } else if (elem instanceof NDIndexAll) { + PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, -1, -1, 1, 3); + } else if (elem instanceof NDIndexFixed) { + PyTorchLibrary.LIB.torchIndexAppendFixed( + torchIndexHandle, ((NDIndexFixed) elem).getIndex()); + } else if (elem instanceof NDIndexBooleans) { + PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex(); + PyTorchLibrary.LIB.torchIndexAppendArray( + torchIndexHandle, indexArr.getHandle()); + } else if (elem instanceof NDIndexTake) { + PtNDArray indexArr = manager.from(((NDIndexTake) elem).getIndex()); + if (indexArr.getDataType() != DataType.INT64) { + indexArr = indexArr.toType(DataType.INT64, true); + } + PyTorchLibrary.LIB.torchIndexAppendArray( + torchIndexHandle, indexArr.getHandle()); + } else if (elem instanceof NDIndexPick) { + // Backward compatible + NDIndexFullPick fullPick = + NDIndexFullPick.fromIndex(index, ndArray.getShape()).get(); + return pick(ndArray, manager.from(fullPick.getIndices()), fullPick.getAxis()); } - PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle()); - } else if (elem instanceof NDIndexPick) { - // Backward compatible - NDIndexFullPick fullPick = - NDIndexFullPick.fromIndex(index, ndArray.getShape()).get(); - return pick(ndArray, manager.from(fullPick.getIndices()), fullPick.getAxis()); } + if (indices.size() == index.getEllipsisIndex()) { + PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); + } + result = + new PtNDArray( + manager, + PyTorchLibrary.LIB.torchIndexAdvGet( + ndArray.getHandle(), torchIndexHandle)); + } finally { + PyTorchLibrary.LIB.torchDeleteIndex(torchIndexHandle); } - if (indices.size() == index.getEllipsisIndex()) { - PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); - } - - PtNDArray ret = - new PtNDArray( - manager, - PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle)); - PyTorchLibrary.LIB.torchDeleteIndex(torchIndexHandle); - return ret; + return result; } @SuppressWarnings("OptionalGetWithoutIsPresent") @@ -426,63 +433,70 @@ public static void indexAdvPut(PtNDArray ndArray, NDIndex index, PtNDArray data) if (ndArray == null) { return; } - // Index aggregation List indices = index.getIndices(); long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size()); - ListIterator it = indices.listIterator(); - while (it.hasNext()) { - if (it.nextIndex() == index.getEllipsisIndex()) { - PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); - } + try { + // Index aggregation + ListIterator it = indices.listIterator(); + while (it.hasNext()) { + if (it.nextIndex() == index.getEllipsisIndex()) { + PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); + } - NDIndexElement elem = it.next(); - if (elem instanceof NDIndexNull) { - PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false); - } else if (elem instanceof NDIndexSlice) { - Long min = ((NDIndexSlice) elem).getMin(); - Long max = ((NDIndexSlice) elem).getMax(); - Long step = ((NDIndexSlice) elem).getStep(); - int nullSliceBinary = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); - // nullSliceBinary encodes whether the slice end {min, max} is null: - // is_null == 1, ! is_null == 0; - // 0b11 == 3, 0b10 = 2, ... - // If {min, max} is null, then its value is ineffective, thus set to -1. - PyTorchLibrary.LIB.torchIndexAppendSlice( - torchIndexHandle, - min == null ? -1 : min, - max == null ? -1 : max, - step == null ? 1 : step, - nullSliceBinary); - } else if (elem instanceof NDIndexAll) { - PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, -1, -1, 1, 3); - } else if (elem instanceof NDIndexFixed) { - PyTorchLibrary.LIB.torchIndexAppendFixed( - torchIndexHandle, ((NDIndexFixed) elem).getIndex()); - } else if (elem instanceof NDIndexBooleans) { - PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex(); - PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle()); - } else if (elem instanceof NDIndexTake) { - PtNDArray indexArr = (PtNDArray) ((NDIndexTake) elem).getIndex(); - if (indexArr.getDataType() != DataType.INT64) { - indexArr = indexArr.toType(DataType.INT64, true); + NDIndexElement elem = it.next(); + if (elem instanceof NDIndexNull) { + PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false); + } else if (elem instanceof NDIndexSlice) { + Long min = ((NDIndexSlice) elem).getMin(); + Long max = ((NDIndexSlice) elem).getMax(); + Long step = ((NDIndexSlice) elem).getStep(); + int nullSliceBinary = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); + // nullSliceBinary encodes whether the slice end {min, max} is null: + // is_null == 1, ! is_null == 0; + // 0b11 == 3, 0b10 = 2, ... + // If {min, max} is null, then its value is ineffective, thus set to -1. + PyTorchLibrary.LIB.torchIndexAppendSlice( + torchIndexHandle, + min == null ? -1 : min, + max == null ? -1 : max, + step == null ? 1 : step, + nullSliceBinary); + } else if (elem instanceof NDIndexAll) { + PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, -1, -1, 1, 3); + } else if (elem instanceof NDIndexFixed) { + PyTorchLibrary.LIB.torchIndexAppendFixed( + torchIndexHandle, ((NDIndexFixed) elem).getIndex()); + } else if (elem instanceof NDIndexBooleans) { + PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex(); + PyTorchLibrary.LIB.torchIndexAppendArray( + torchIndexHandle, indexArr.getHandle()); + } else if (elem instanceof NDIndexTake) { + PtNDArray indexArr = (PtNDArray) ((NDIndexTake) elem).getIndex(); + if (indexArr.getDataType() != DataType.INT64) { + indexArr = indexArr.toType(DataType.INT64, true); + } + PyTorchLibrary.LIB.torchIndexAppendArray( + torchIndexHandle, indexArr.getHandle()); + } else if (elem instanceof NDIndexPick) { + // Backward compatible + NDIndexFullPick fullPick = + NDIndexFullPick.fromIndex(index, ndArray.getShape()).get(); + pick( + ndArray, + ndArray.getManager().from(fullPick.getIndices()), + fullPick.getAxis()); + return; } - PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle()); - } else if (elem instanceof NDIndexPick) { - // Backward compatible - NDIndexFullPick fullPick = - NDIndexFullPick.fromIndex(index, ndArray.getShape()).get(); - pick(ndArray, ndArray.getManager().from(fullPick.getIndices()), fullPick.getAxis()); - return; } - } - if (indices.size() == index.getEllipsisIndex()) { - PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); - } - - PyTorchLibrary.LIB.torchIndexAdvPut( - ndArray.getHandle(), torchIndexHandle, data.getHandle()); + if (indices.size() == index.getEllipsisIndex()) { + PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); + } - PyTorchLibrary.LIB.torchDeleteIndex(torchIndexHandle); + PyTorchLibrary.LIB.torchIndexAdvPut( + ndArray.getHandle(), torchIndexHandle, data.getHandle()); + } finally { + PyTorchLibrary.LIB.torchDeleteIndex(torchIndexHandle); + } } public static void indexSet( From 085dacda7f29e92f2df70f30bd970ecb34e22a98 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Fri, 6 Jan 2023 11:24:46 -0800 Subject: [PATCH 5/5] minor update the code style --- .../src/main/java/ai/djl/pytorch/jni/JniUtils.java | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 431568da61a..875be451274 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -362,7 +362,6 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m if (ndArray == null) { return ndArray; } - PtNDArray result; List indices = index.getIndices(); long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size()); try { @@ -417,15 +416,11 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m if (indices.size() == index.getEllipsisIndex()) { PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); } - result = - new PtNDArray( - manager, - PyTorchLibrary.LIB.torchIndexAdvGet( - ndArray.getHandle(), torchIndexHandle)); + long ret = PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle); + return new PtNDArray(manager, ret); } finally { PyTorchLibrary.LIB.torchDeleteIndex(torchIndexHandle); } - return result; } @SuppressWarnings("OptionalGetWithoutIsPresent")