
[pytorch] optimize memory copy cost for pytorch NDArray #3137

Merged — 2 commits merged into deepjavalibrary:master on May 13, 2024

Conversation

ewan0x79 (Contributor)

Description

Currently, the public double[] toDoubleArray() method of the PyTorch PtNDArray copies the tensor contents from native memory into heap memory, and then copies them again from heap memory into the array object. This incurs extra copy overhead; the data can instead be copied directly from native memory into the array.
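
For illustration, a rough sketch of the difference (variable names such as handle are simplified placeholders, not the actual DJL code): the old path materializes an intermediate byte[] on the heap, while the new path reads through a direct ByteBuffer that views the tensor's native data.

// Old path (two copies): native memory -> heap byte[] -> double[]
byte[] bytes = PyTorchLibrary.LIB.torchDataPtr(handle);            // copy 1: native -> heap array
DoubleBuffer heapView =
        ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).asDoubleBuffer();
double[] viaHeap = new double[heapView.remaining()];
heapView.get(viaHeap);                                             // copy 2: heap buffer -> result array

// New path (one copy): a direct ByteBuffer views the native data and is copied once
DoubleBuffer directView = PyTorchLibrary.LIB.torchDirectByteBuffer(handle)
        .order(ByteOrder.nativeOrder())
        .asDoubleBuffer();
double[] viaDirect = new double[directView.remaining()];
directView.get(viaDirect);                                         // single copy: native -> result array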

@@ -50,6 +50,8 @@ private PyTorchLibrary() {}

native byte[] torchDataPtr(long handle);

native ByteBuffer torchDirectByteBuffer(long handle);

We should just replace torchDataPtr() with your new method.

@@ -1680,6 +1680,15 @@ public static ByteBuffer getByteBuffer(PtNDArray ndArray) {
.order(ByteOrder.nativeOrder());
}

public static ByteBuffer getDirectByteBuffer(PtNDArray ndArray) {

replace getByteBuffer() with your implementation


I still think we should return a DirectBuffer in getByteBuffer(). I agree there is a risk of potentially sharing native memory between two tensors, but we can mitigate that risk by overriding the PtNDArray.copyTo() function and copying the buffer if it is a DirectBuffer.
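
A minimal sketch of what such an override might look like, assuming toByteBuffer() were changed to return a direct buffer; the body below is only illustrative, not the implementation in this PR:

@Override
public void copyTo(NDArray array) {
    ByteBuffer src = toByteBuffer();
    if (src.isDirect()) {
        // defensively copy into a fresh heap buffer so the two tensors never share native memory
        ByteBuffer copy = ByteBuffer.allocate(src.remaining()).order(src.order());
        copy.put(src.duplicate());
        copy.rewind();
        array.set(copy);
    } else {
        array.set(src);
    }
}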

@@ -45,4 +45,126 @@ public void testLargeTensor() {
Assert.assertThrows(EngineException.class, array::toByteArray);
}
}

@Test
public static void testPtTensorToLongArray() {

All these tests are already covered in the integration tests; we don't need them here.

ewan0x79 (Contributor, Author) commented Apr 30, 2024

Thank you for your response. If I understand correctly, you mean that ai.djl.pytorch.engine.PtNDArray#toByteBuffer should return a DirectBuffer. I considered this idea initially, but the current framework assumes it returns a non-direct buffer, and changing that directly would cause issues. For example, the function ai.djl.ndarray.NDArray#copyTo,

    default void copyTo(NDArray array) {
        array.set(toByteBuffer());
    }

which calls ai.djl.pytorch.engine.PtNDArray#toByteBuffer to obtain a Buffer, assumes the result is non-direct memory. Therefore, in ai.djl.pytorch.engine.PtNDArray#set:

@Override
public void set(Buffer buffer) {
    int size = Math.toIntExact(size());
    DataType type = getDataType();
    BaseNDManager.validateBuffer(buffer, type, size);
    // TODO how do we handle the exception happened in the middle
    dataRef = null;
    if (buffer.isDirect() && buffer instanceof ByteBuffer) {
        // If NDArray is on the GPU, it is native code responsibility to control the data life cycle
        if (!getDevice().isGpu()) {
            dataRef = (ByteBuffer) buffer;
        }
        JniUtils.set(this, (ByteBuffer) buffer);
        return;
    }
    // int8, uint8, boolean use ByteBuffer, so need to explicitly input DataType
    ByteBuffer buf = manager.allocateDirect(size * type.getNumOfBytes());
    BaseNDManager.copyBuffer(buffer, buf);
    // If NDArray is on the GPU, it is native code responsibility to control the data life cycle
    if (!getDevice().isGpu()) {
        dataRef = buf;
    }
    JniUtils.set(this, buf);
}

By detecting that the buffer is not direct memory, set() can allocate new direct memory and copy the data into it, which achieves a deep copy. If ai.djl.pytorch.engine.PtNDArray#toByteBuffer returned direct memory, both NDArrays would end up sharing the same memory. Ideally, a toDirectByteBuffer method should be added to ai.djl.ndarray.NDArray so that direct and non-direct buffers can be distinguished. However, this would require modifying multiple engines, such as the OnnxRuntime engine. Although OnnxRuntime's underlying layer provides ai.onnxruntime.OnnxTensor#getBuffer() for obtaining direct memory, it is a private method and cannot be used directly. Therefore, at this stage I have only made the adjustment for PyTorch.
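
One possible shape for such an API, purely as a sketch (the default body below is hypothetical and not part of DJL):

// Hypothetical default method on ai.djl.ndarray.NDArray (sketch only): engines that can expose
// native memory would override this; others fall back to copying the heap buffer into direct memory.
default ByteBuffer toDirectByteBuffer() {
    ByteBuffer buf = toByteBuffer();
    if (buf.isDirect()) {
        return buf;
    }
    ByteBuffer direct = ByteBuffer.allocateDirect(buf.remaining()).order(buf.order());
    direct.put(buf.duplicate());
    direct.rewind();
    return direct;
}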

ewan0x79 (Contributor, Author)

Regarding the byte order issue you mentioned, I observed that in the ONNX Runtime engine, the byte order is set at the Java level:

/*
 * Class:     ai_onnxruntime_OnnxTensor
 * Method:    getBuffer
 * Signature: (JJ)Ljava/nio/ByteBuffer;
 */
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer
        (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
    (void) jobj;  // Required JNI parameter not needed by functions that don't need to access their host object.
    const OrtApi* api = (const OrtApi*) apiHandle;
    OrtValue* ortValue = (OrtValue *) handle;
    JavaTensorTypeShape typeShape;
    OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue);
    if (code == ORT_OK) {
      size_t typeSize = onnxTypeSize(typeShape.onnxTypeEnum);
      size_t sizeBytes = typeShape.elementCount * typeSize;
      uint8_t* arr = NULL;
      code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
      if (code == ORT_OK) {
        return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, (jlong)sizeBytes);
      }
    }
    return NULL;
}
private ByteBuffer getBuffer() {
    return getBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder());
}
private native ByteBuffer getBuffer(long apiHandle, long nativeHandle);

Following this approach, I also set the byte order at the Java level:

public static ByteBuffer getDirectByteBuffer(PtNDArray ndArray) {
    // Operation is CPU only
    if (!ndArray.getDevice().equals(Device.cpu())) {
        ndArray = ndArray.toDevice(Device.cpu(), false);
    }
    return PyTorchLibrary.LIB.torchDirectByteBuffer(ndArray.getHandle())
            .order(ByteOrder.nativeOrder());
}

Moreover, I am not aware of any API that can set the byte order at the time of NewDirectByteBuffer.
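
As a small illustration of why the order has to be applied on the Java side: a ByteBuffer created by JNI (or by allocateDirect) defaults to big-endian regardless of the platform, so multi-byte reads would be wrong on little-endian machines without the explicit order(...) call.

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class ByteOrderDemo {
    public static void main(String[] args) {
        ByteBuffer buf = ByteBuffer.allocateDirect(8);
        System.out.println(buf.order());   // BIG_ENDIAN by default
        buf.order(ByteOrder.nativeOrder());
        System.out.println(buf.order());   // LITTLE_ENDIAN on most x86/ARM machines
    }
}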

frankfliu force-pushed the master_track branch 4 times, most recently from 65c1d59 to e331259, on May 7, 2024 05:39
frankfliu (Contributor)

@ewan0x79

I tried to fix the integration test failure, but I couldn't make it work. I don't think this solution can really work. When GC kicks in and deletes the DirectBuffer, in certain cases the native memory gets trashed. It can cause data corruption in multiple places.

ewan0x79 force-pushed the master_track branch 3 times, most recently from d726d1e to 0480cb6, on May 8, 2024 09:17
ewan0x79 (Contributor, Author) commented May 8, 2024

@frankfliu

You're right, there was indeed a problem with my previous approach.

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDirectByteBuffer(JNIEnv* env, jobject jthis, jlong jhandle)
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
// sparse and mkldnn are required to be converted to dense to access data ptr
auto tensor = (tensor_ptr->is_sparse() || tensor_ptr->is_mkldnn()) ? tensor_ptr->to_dense() : *tensor_ptr;
tensor = (tensor.is_contiguous()) ? tensor : tensor.contiguous();
size_t nbytes = tensor.nbytes();
if (nbytes > 0x7fffffff) {
  env->ThrowNew(ENGINE_EXCEPTION_CLASS, "toDirectByteBuffer() is not supported for large tensor");
  return nullptr;
}
// Use tensor.data_ptr() to get the data pointer and create a direct ByteBuffer with NewDirectByteBuffer
void* data_ptr = tensor.data_ptr();
jobject directBuffer = env->NewDirectByteBuffer(data_ptr, nbytes);
return directBuffer;
API_END_RETURN()

Previously, I overlooked the fact that tensor.contiguous() returns a new tensor. When a new tensor is returned, the local variable tensor on the stack points to a new tensor in native memory. The intention of jobject directBuffer = env->NewDirectByteBuffer(data_ptr, nbytes); is to let the ByteBuffer hold the address of the tensor's native data. However, after the native method returns, the local variables on the stack are destroyed, so the new tensor in native memory is freed as well (this part is handled by C++), leaving the ByteBuffer pointing to an invalid address. Therefore, calling this method could cause issues.

Now, I have made some modifications. Here's the updated code snippet:

public double[] toDoubleArray() {
    if (getDataType() != DataType.FLOAT64) {
        throw new IllegalStateException(
                "DataType mismatch, Required double" + " Actual " + getDataType());
    }
    if (isSparse() || JniUtils.getLayout(this) == 2 || !isContiguous()) {
        try (final PtNDArray ptNDArray = toContiguous()) {
            return toDoubleArray(ptNDArray);
        }
    } else {
        return toDoubleArray(this);
    }
}

We need to determine whether the tensor is contiguous. If it is contiguous, we can directly use a ByteBuffer to view the native memory address of the data, just like before. If it is not contiguous, we need to construct a new tensor (moving the original non-contiguous data into contiguous memory); the ByteBuffer then holds the native memory address of this new tensor, which is destroyed after the data has been copied from native memory into the array.
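
A minimal sketch of what the private helper could look like, assuming it reads through the direct buffer returned by JniUtils.getDirectByteBuffer (the helper body is illustrative only):

// Hypothetical helper sketch: 'array' is assumed to be dense and contiguous at this point,
// so the direct buffer is a valid view of its native data for the duration of the copy.
private double[] toDoubleArray(PtNDArray array) {
    DoubleBuffer db = JniUtils.getDirectByteBuffer(array).asDoubleBuffer();
    double[] ret = new double[db.remaining()];
    db.get(ret); // single copy: native memory -> result array
    return ret;
}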

*
* @return A new {@code PtNDArray} that is guaranteed to be contiguous.
*/
public PtNDArray toContiguous() {

I don't feel we need to expose isContiguous() and toContiguous() at the Java level. We should be able to handle this at the torchGetDirectByteBuffer() level.


ewan0x79 (Contributor, Author) commented May 9, 2024

@frankfliu

Your modifications still have certain issues.

  • Firstly, in the function JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDirectByteBuffer, you copy non-contiguous memory into a newly allocated byte array and then let the ByteBuffer hold the direct memory address of that data:
void* data_ptr = tensor.data_ptr();
  if (tensor_ptr->is_sparse() || tensor_ptr->is_mkldnn() || !tensor_ptr->is_contiguous()) {
    // We have to make a copy anyway
    void* data = new char[nbytes];
    data_ptr = std::memcpy(data, data_ptr, nbytes);
  }
  return env->NewDirectByteBuffer(data_ptr, nbytes); 

However, the JVM will not free the native memory backing the ByteBuffer when it garbage-collects the ByteBuffer, so this will lead to a native memory leak.

  • Secondly, letting getByteBuffer return direct memory is not a problem for now, but if similar modifications are applied to other engines in the future, issues may arise. For example, the method ai.djl.ndarray.NDArrayAdapter#getAlternativeArray assumes that both the PyTorch and OnnxRuntime engines exist:
 private NDArray getAlternativeArray() {
        if (alternativeManager == null) {
            throw new UnsupportedOperationException(UNSUPPORTED_MSG);
        }
        if (alternativeArray == null) {
            alternativeArray = alternativeManager.from(this);
        } else {
            alternativeArray.set(getDataType().asDataType(toByteBuffer()));
        }
        return alternativeArray;
    }

In this case, it will look for the PyTorch engine and execute alternativeArray = alternativeManager.from(this);. In PyTorch's ai.djl.pytorch.engine.PtNDManager#from,

   @Override
    public PtNDArray from(NDArray array) {
        if (array == null || array instanceof PtNDArray) {
            return (PtNDArray) array;
        }
        PtNDArray result = create(array.toByteBuffer(), array.getShape(), array.getDataType());
        result.setName(array.getName());
        return result;
    }      
@Override
    public PtNDArray create(Buffer data, Shape shape, DataType dataType) {
        int size = Math.toIntExact(shape.size());
        BaseNDManager.validateBuffer(data, dataType, size);
        if (data.isDirect() && data instanceof ByteBuffer) {
            return JniUtils.createNdFromByteBuffer(
                    this, (ByteBuffer) data, shape, dataType, SparseFormat.DENSE, device);
        }
        ByteBuffer buf = allocateDirect(size * dataType.getNumOfBytes());
        copyBuffer(data, buf);
        return JniUtils.createNdFromByteBuffer(
                this, buf, shape, dataType, SparseFormat.DENSE, device);
    }

it checks whether the Buffer is direct memory. If ONNX Runtime also returned direct memory through getByteBuffer, two NDArrays might end up sharing the same memory. I am not sure whether their underlying formats are the same (operations on a tensor might change the underlying data; for example, one tensor could be released while the other is not, or one could change the underlying data format without the other knowing). This seems to pose a significant risk. Of course, this could also be handled by overriding ai.djl.pytorch.engine.PtNDManager#create(java.nio.Buffer, ai.djl.ndarray.types.Shape, ai.djl.ndarray.types.DataType), but that might not be easy.

frankfliu (Contributor)

@ewan0x79

You are right, my changes would cause a memory leak. I will revert my part.

codecov-commenter

Codecov Report

Attention: Patch coverage is 79.16667%, with 5 lines in your changes missing coverage. Please review.

Project coverage is 68.43%. Comparing base (6efe660) to head (9c9592a).
Report is 220 commits behind head on master.

Files Patch % Lines
api/src/main/java/ai/djl/ndarray/NDArray.java 72.72% 3 Missing ⚠️
...ine/src/main/java/ai/djl/pytorch/jni/JniUtils.java 80.00% 0 Missing and 2 partials ⚠️


Additional details and impacted files
@@             Coverage Diff              @@
##             master    #3137      +/-   ##
============================================
- Coverage     71.03%   68.43%   -2.60%     
+ Complexity     7199     7031     -168     
============================================
  Files           694      697       +3     
  Lines         32614    32765     +151     
  Branches       3374     3409      +35     
============================================
- Hits          23166    22423     -743     
- Misses         7842     8732     +890     
- Partials       1606     1610       +4     


frankfliu merged commit 2d9c84b into deepjavalibrary:master on May 13, 2024
5 checks passed