[pytorch] optimize memory copy cost for pytorch NDArray #3137
Conversation
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java
```diff
@@ -50,6 +50,8 @@ private PyTorchLibrary() {}
 
     native byte[] torchDataPtr(long handle);
 
+    native ByteBuffer torchDirectByteBuffer(long handle);
```
We should just replace `torchDataPtr()` with your new method.
```diff
@@ -1680,6 +1680,15 @@ public static ByteBuffer getByteBuffer(PtNDArray ndArray) {
                 .order(ByteOrder.nativeOrder());
     }
 
+    public static ByteBuffer getDirectByteBuffer(PtNDArray ndArray) {
```
Replace `getByteBuffer()` with your implementation.
I still think we should return a DirectBuffer in `getByteBuffer()`. I agree there is a risk of potentially sharing the native memory between two tensors, but we can mitigate that risk by overriding `PtNDArray.copyTo()` and copying the buffer if it is a DirectBuffer.
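A minimal sketch of that mitigation, assuming `toByteBuffer()` may return a direct buffer that aliases the tensor's native memory (method fragment meant to live inside `PtNDArray`; the exact override point and surrounding class are omitted, and this is not code from this PR):

```java
// Hypothetical override: before handing the data to another array, detach it
// from native memory if the source buffer is direct, so the two tensors never
// share the same backing storage. Uses java.nio.ByteBuffer / ByteOrder.
@Override
public void copyTo(NDArray ndArray) {
    ByteBuffer buf = toByteBuffer(); // may alias native memory (assumption)
    if (buf.isDirect()) {
        ByteBuffer copy =
                ByteBuffer.allocateDirect(buf.remaining()).order(ByteOrder.nativeOrder());
        copy.put(buf.duplicate()); // deep copy out of the shared region
        copy.rewind();
        buf = copy;
    }
    ndArray.set(buf);
}
```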
```diff
@@ -45,4 +45,126 @@ public void testLargeTensor() {
         Assert.assertThrows(EngineException.class, array::toByteArray);
     }
 }
 
+    @Test
+    public static void testPtTensorToLongArray() {
```
All of these tests are already covered by the integration tests; we don't need them here.
engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc
Thank you for your response. What you mean is to set …, which calls:

```java
@Override
public void set(Buffer buffer) {
    int size = Math.toIntExact(size());
    DataType type = getDataType();
    BaseNDManager.validateBuffer(buffer, type, size);
    // TODO how do we handle the exception happened in the middle
    dataRef = null;
    if (buffer.isDirect() && buffer instanceof ByteBuffer) {
        // If NDArray is on the GPU, it is native code responsibility to control the data life cycle
        if (!getDevice().isGpu()) {
            dataRef = (ByteBuffer) buffer;
        }
        JniUtils.set(this, (ByteBuffer) buffer);
        return;
    }
    // int8, uint8, boolean use ByteBuffer, so need to explicitly input DataType
    ByteBuffer buf = manager.allocateDirect(size * type.getNumOfBytes());
    BaseNDManager.copyBuffer(buffer, buf);
    // If NDArray is on the GPU, it is native code responsibility to control the data life cycle
    if (!getDevice().isGpu()) {
        dataRef = buf;
    }
    JniUtils.set(this, buf);
}
```

By determining that the buffer is not direct memory, it can create new direct memory and perform the data copy, thus achieving a deep copy. If …
Regarding the byte order issue you mentioned, I observed that in the ONNX Runtime engine the byte order is set at the Java level:

```c
/*
 * Class:     ai_onnxruntime_OnnxTensor
 * Method:    getBuffer
 * Signature: (JJ)Ljava/nio/ByteBuffer;
 */
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer
  (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
  (void) jobj;  // Required JNI parameter not needed by functions that don't need to access their host object.
  const OrtApi* api = (const OrtApi*) apiHandle;
  OrtValue* ortValue = (OrtValue *) handle;
  JavaTensorTypeShape typeShape;
  OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue);
  if (code == ORT_OK) {
    size_t typeSize = onnxTypeSize(typeShape.onnxTypeEnum);
    size_t sizeBytes = typeShape.elementCount * typeSize;
    uint8_t* arr = NULL;
    code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
    if (code == ORT_OK) {
      return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, (jlong)sizeBytes);
    }
  }
  return NULL;
}
```

```java
private ByteBuffer getBuffer() {
    return getBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder());
}

private native ByteBuffer getBuffer(long apiHandle, long nativeHandle);
```

It sets the byte order at the Java level, so following this approach I also set it at the Java level:

```java
public static ByteBuffer getDirectByteBuffer(PtNDArray ndArray) {
    // Operation is CPU only
    if (!ndArray.getDevice().equals(Device.cpu())) {
        ndArray = ndArray.toDevice(Device.cpu(), false);
    }
    return PyTorchLibrary.LIB.torchDirectByteBuffer(ndArray.getHandle())
            .order(ByteOrder.nativeOrder());
}
```

Moreover, I am not aware of whether there is an API that can set the byte order at the time of …
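For context on why the `.order(ByteOrder.nativeOrder())` call matters (a standalone illustration, not code from this PR): a `ByteBuffer` obtained from JNI's `NewDirectByteBuffer`, like any newly created `ByteBuffer`, defaults to big-endian, so typed views would misread the native data on little-endian platforms unless the order is set on the Java side.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public final class ByteOrderDemo {
    public static void main(String[] args) {
        // A fresh ByteBuffer (including ones created by JNI's NewDirectByteBuffer)
        // starts out BIG_ENDIAN, regardless of the platform.
        ByteBuffer buf = ByteBuffer.allocateDirect(Double.BYTES);
        System.out.println(buf.order());            // BIG_ENDIAN

        // Setting the order on the Java side, as ONNX Runtime does above, makes
        // typed views such as asDoubleBuffer() interpret the raw bytes correctly.
        buf.order(ByteOrder.nativeOrder());
        System.out.println(buf.order());            // e.g. LITTLE_ENDIAN on x86/ARM
    }
}
```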
Force-pushed from 65c1d59 to e331259.
I tried to fix the integration test failure, but I cannot make it work. I don't think this solution can really work: when GC kicks in and deletes …
Force-pushed from d726d1e to 0480cb6.
You're right, there was indeed a problem with my previous approach.

```cpp
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDirectByteBuffer(
    JNIEnv* env, jobject jthis, jlong jhandle) {
  API_BEGIN()
  const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
  // sparse and mkldnn are required to be converted to dense to access data ptr
  auto tensor = (tensor_ptr->is_sparse() || tensor_ptr->is_mkldnn()) ? tensor_ptr->to_dense() : *tensor_ptr;
  tensor = (tensor.is_contiguous()) ? tensor : tensor.contiguous();
  size_t nbytes = tensor.nbytes();
  if (nbytes > 0x7fffffff) {
    env->ThrowNew(ENGINE_EXCEPTION_CLASS, "toDirectByteBuffer() is not supported for large tensor");
    return nullptr;
  }
  // Use tensor.data_ptr() to get the data pointer and create a direct ByteBuffer with NewDirectByteBuffer
  void* data_ptr = tensor.data_ptr();
  jobject directBuffer = env->NewDirectByteBuffer(data_ptr, nbytes);
  return directBuffer;
  API_END_RETURN()
}
```

Previously, I overlooked the fact that …. Now I have made some modifications. Here's the updated code snippet:

```java
public double[] toDoubleArray() {
    if (getDataType() != DataType.FLOAT64) {
        throw new IllegalStateException(
                "DataType mismatch, Required double" + " Actual " + getDataType());
    }
    if (isSparse() || JniUtils.getLayout(this) == 2 || !isContiguous()) {
        try (final PtNDArray ptNDArray = toContiguous()) {
            return toDoubleArray(ptNDArray);
        }
    } else {
        return toDoubleArray(this);
    }
}
```

We need to determine whether the tensor is contiguous.
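The snippet above delegates to a `toDoubleArray(PtNDArray)` helper that is not shown in this excerpt. A minimal sketch of what such a helper might look like, assuming it reads through the direct buffer returned by `getDirectByteBuffer()` (method fragment inside `PtNDArray`; the name and body are assumptions, not code from this PR):

```java
// Hypothetical helper: read the tensor's contents through the direct buffer so
// the data goes from native memory into the double[] in a single copy.
// Uses java.nio.ByteBuffer / DoubleBuffer.
private static double[] toDoubleArray(PtNDArray array) {
    ByteBuffer buf = JniUtils.getDirectByteBuffer(array); // already in native byte order
    DoubleBuffer db = buf.asDoubleBuffer();
    double[] result = new double[db.remaining()];
    db.get(result);                                       // native memory -> array, one copy
    return result;
}
```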
```diff
+     *
+     * @return A new {@code PtNDArray} that is guaranteed to be contiguous.
+     */
+    public PtNDArray toContiguous() {
```
I don't feel we need to expose `isContiguous()` and `toContiguous()` at the Java level. We should be able to handle this at the `torchGetDirectByteBuffer()` level.
Your modifications still have certain issues. You use the following method to copy non-contiguous memory to a new byte array, and then let the ByteBuffer hold the direct memory address of …. In this case, it will look for the PyTorch engine and execute …, where it is determined whether the Buffer is direct memory. If ONNX also returns direct memory through …
You are right, my changes will cause a memory leak. I will revert my part.
Codecov Report

Attention: Patch coverage is …

```
@@             Coverage Diff              @@
##             master    #3137      +/-   ##
============================================
- Coverage     71.03%   68.43%    -2.60%
+ Complexity     7199     7031      -168
============================================
  Files           694      697        +3
  Lines         32614    32765      +151
  Branches       3374     3409       +35
============================================
- Hits          23166    22423      -743
- Misses         7842     8732      +890
- Partials       1606     1610        +4
```
Description

Currently, the `public double[] toDoubleArray()` method in the PyTorch NDArray copies the contents of native memory into heap memory, and then copies it again from heap memory into the array object. This incurs extra copy overhead; the data can instead be copied directly from native memory into the target array.
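As a rough, self-contained illustration of the two copy paths described above (the direct buffer here merely stands in for the tensor's native storage; this is not the actual DJL implementation):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

public final class CopyPathDemo {
    public static void main(String[] args) {
        // Stand-in for the tensor's native storage.
        ByteBuffer nativeData =
                ByteBuffer.allocateDirect(3 * Double.BYTES).order(ByteOrder.nativeOrder());
        nativeData.asDoubleBuffer().put(new double[] {1.0, 2.0, 3.0});

        // Old path: native memory -> heap byte[] -> double[] (two copies).
        byte[] heap = new byte[nativeData.capacity()];
        nativeData.duplicate().get(heap);                                    // copy 1
        double[] viaHeap = new double[3];
        ByteBuffer.wrap(heap).order(ByteOrder.nativeOrder()).asDoubleBuffer().get(viaHeap); // copy 2

        // New path: read the double[] straight out of the direct buffer (one copy).
        double[] viaDirect = new double[3];
        nativeData.duplicate().order(ByteOrder.nativeOrder()).asDoubleBuffer().get(viaDirect); // copy 1

        System.out.println(Arrays.equals(viaHeap, viaDirect));               // true
    }
}
```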