Skip to content

Commit

Permalink
[pytorch] optimize memory copy cost for pytorch NDArray (#3137)
Browse files Browse the repository at this point in the history
* [pytorch] optimize memory copy cost for pytorch NDArray

---------

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
  • Loading branch information
ewan0x79 and frankfliu authored May 13, 2024
1 parent abae500 commit 2d9c84b
Show file tree
Hide file tree
Showing 15 changed files with 111 additions and 29 deletions.
39 changes: 29 additions & 10 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,25 @@ default long size() {
return getShape().size();
}

/** {@inheritDoc} */
@Override
default ByteBuffer toByteBuffer() {
return toByteBuffer(false);
}

/**
* Returns the {@code ByteBuffer} presentation of the object.
*
* <p>If returned ByteBuffer is a DirectByteBuffer, it shared the same native memory as the
* NDArray. The native memory will be deleted when NDArray is closed.
*
* <p>Not all the engine support return DirectByteBuffer.
*
* @param tryDirect use DirectBuffer if possible
* @return the {@code ByteBuffer} presentation of the object
*/
ByteBuffer toByteBuffer(boolean tryDirect);

/**
* Converts this {@code NDArray} to a double array.
*
Expand All @@ -236,7 +255,7 @@ default double[] toDoubleArray() {
throw new IllegalStateException(
"DataType mismatch, Required double" + " Actual " + getDataType());
}
DoubleBuffer db = toByteBuffer().asDoubleBuffer();
DoubleBuffer db = toByteBuffer(true).asDoubleBuffer();
double[] ret = new double[db.remaining()];
db.get(ret);
return ret;
Expand All @@ -255,7 +274,7 @@ default float[] toFloatArray() {
throw new IllegalStateException(
"DataType mismatch, Required float, Actual " + getDataType());
}
FloatBuffer fb = toByteBuffer().asFloatBuffer();
FloatBuffer fb = toByteBuffer(true).asFloatBuffer();
float[] ret = new float[fb.remaining()];
fb.get(ret);
return ret;
Expand All @@ -272,7 +291,7 @@ default short[] toShortArray() {
throw new IllegalStateException(
"DataType mismatch, Required int" + " Actual " + getDataType());
}
ShortBuffer ib = toByteBuffer().asShortBuffer();
ShortBuffer ib = toByteBuffer(true).asShortBuffer();
short[] ret = new short[ib.remaining()];
ib.get(ret);
return ret;
Expand All @@ -289,7 +308,7 @@ default int[] toUnsignedShortArray() {
throw new IllegalStateException(
"DataType mismatch, Required int" + " Actual " + getDataType());
}
ShortBuffer ib = toByteBuffer().asShortBuffer();
ShortBuffer ib = toByteBuffer(true).asShortBuffer();
int[] ret = new int[ib.remaining()];
for (int i = 0; i < ret.length; ++i) {
ret[i] = ib.get() & 0xffff;
Expand All @@ -309,7 +328,7 @@ default int[] toIntArray() {
throw new IllegalStateException(
"DataType mismatch, Required int" + " Actual " + getDataType());
}
IntBuffer ib = toByteBuffer().asIntBuffer();
IntBuffer ib = toByteBuffer(true).asIntBuffer();
int[] ret = new int[ib.remaining()];
ib.get(ret);
return ret;
Expand All @@ -326,7 +345,7 @@ default long[] toUnsignedIntArray() {
throw new IllegalStateException(
"DataType mismatch, Required int" + " Actual " + getDataType());
}
IntBuffer ib = toByteBuffer().asIntBuffer();
IntBuffer ib = toByteBuffer(true).asIntBuffer();
long[] ret = new long[ib.remaining()];
for (int i = 0; i < ret.length; ++i) {
ret[i] = ib.get() & 0X00000000FFFFFFFFL;
Expand All @@ -345,7 +364,7 @@ default long[] toLongArray() {
throw new IllegalStateException(
"DataType mismatch, Required long" + " Actual " + getDataType());
}
LongBuffer lb = toByteBuffer().asLongBuffer();
LongBuffer lb = toByteBuffer(true).asLongBuffer();
long[] ret = new long[lb.remaining()];
lb.get(ret);
return ret;
Expand All @@ -358,7 +377,7 @@ default long[] toLongArray() {
* @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
*/
default byte[] toByteArray() {
ByteBuffer bb = toByteBuffer();
ByteBuffer bb = toByteBuffer(true);
if (bb.hasArray() && bb.remaining() == bb.array().length) {
return bb.array();
}
Expand All @@ -374,7 +393,7 @@ default byte[] toByteArray() {
* @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
*/
default int[] toUint8Array() {
ByteBuffer bb = toByteBuffer();
ByteBuffer bb = toByteBuffer(true);
int[] buf = new int[bb.remaining()];
for (int i = 0; i < buf.length; ++i) {
buf[i] = bb.get() & 0xff;
Expand All @@ -393,7 +412,7 @@ default boolean[] toBooleanArray() {
throw new IllegalStateException(
"DataType mismatch, Required boolean" + " Actual " + getDataType());
}
ByteBuffer bb = toByteBuffer();
ByteBuffer bb = toByteBuffer(true);
boolean[] ret = new boolean[bb.remaining()];
for (int i = 0; i < ret.length; ++i) {
ret[i] = bb.get() != 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,17 @@ public Object getObject() {

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
public ByteBuffer toByteBuffer(boolean tryDirect) {
if (object instanceof ByteBuffer) {
return (ByteBuffer) object;
}
throw new UnsupportedOperationException("Operation not supported for FastText");
throw new UnsupportedOperationException("Operation not supported for PassthroughNDArray");
}

/** {@inheritDoc} */
@Override
public void intern(NDArray replaced) {
throw new UnsupportedOperationException("Operation not supported for FastText");
throw new UnsupportedOperationException("Operation not supported for PassthroughNDArray");
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public void detach() {

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
public ByteBuffer toByteBuffer(boolean tryDirect) {
throw new UnsupportedOperationException("Not supported by the LgbmDataset yet");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public SparseFormat getSparseFormat() {

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
public ByteBuffer toByteBuffer(boolean tryDirect) {
if (data == null) {
throw new UnsupportedOperationException("Cannot obtain value from DMatrix");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public SparseFormat getSparseFormat() {

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
public ByteBuffer toByteBuffer(boolean tryDirect) {
if (data == null) {
throw new UnsupportedOperationException("Cannot obtain value from DMatrix");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ public String[] toStringArray(Charset charset) {

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
public ByteBuffer toByteBuffer(boolean tryDirect) {
if (getSparseFormat() != SparseFormat.DENSE) {
throw new IllegalStateException("Require Dense NDArray, actual " + getSparseFormat());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public String[] toStringArray(Charset charset) {

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
public ByteBuffer toByteBuffer(boolean tryDirect) {
if (getDataType() == DataType.STRING) {
throw new IllegalArgumentException("Please use toStringArray() for String NDArray.");
}
Expand Down
6 changes: 1 addition & 5 deletions engines/pytorch/pytorch-engine/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,4 @@ publishing {
}
}

clean.doFirst {
delete fileTree(System.getProperty("user.home") + "/.djl.ai/pytorch/") {
include '**/*djl_torch.*'
}
}
clean.dependsOn ":engines:pytorch:pytorch-jni:clean"
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,12 @@ public NDArray stopGradient() {

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
public ByteBuffer toByteBuffer(boolean tryDirect) {
if (getDataType() == DataType.STRING) {
throw new UnsupportedOperationException(
"toByteBuffer is not supported for String tensor.");
}
return JniUtils.getByteBuffer(this);
return JniUtils.getByteBuffer(this, tryDirect);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1671,11 +1671,25 @@ public static Shape getShape(PtNDArray ndArray) {
return new Shape(PyTorchLibrary.LIB.torchSizes(ndArray.getHandle()));
}

public static ByteBuffer getByteBuffer(PtNDArray ndArray) {
public static ByteBuffer getByteBuffer(PtNDArray ndArray, boolean tryDirect) {
// Operation is CPU only
if (!ndArray.getDevice().equals(Device.cpu())) {
ndArray = ndArray.toDevice(Device.cpu(), false);
}
if (tryDirect) {
if (ndArray.isSparse()
|| getLayout(ndArray) == 2
|| !PyTorchLibrary.LIB.torchIsContiguous(ndArray.getHandle())) {
// keep the same lifecycle as origin NDArray
ndArray =
new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchToContiguous(ndArray.getHandle()));
}
return PyTorchLibrary.LIB
.torchDirectByteBuffer(ndArray.getHandle())
.order(ByteOrder.nativeOrder());
}
return ByteBuffer.wrap(PyTorchLibrary.LIB.torchDataPtr(ndArray.getHandle()))
.order(ByteOrder.nativeOrder());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ private PyTorchLibrary() {}

native byte[] torchDataPtr(long handle);

native ByteBuffer torchDirectByteBuffer(long handle);

native boolean torchIsContiguous(long handle);

native long torchToContiguous(long handle);

native int torchDType(long handle);

native int[] torchDevice(long handle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,53 @@ JNIEXPORT jbyteArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDataPtr
API_END_RETURN()
}

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDirectByteBuffer(
JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
// Check if the tensor is sparse or optimized by MKL-DNN, if so, throw an exception
if (tensor_ptr->is_sparse() || tensor_ptr->is_mkldnn()) {
env->ThrowNew(ENGINE_EXCEPTION_CLASS, "torchDirectByteBuffer() is not supported for sparse or MKL-DNN tensors");
return nullptr;
}
// Check if the tensor is contiguous, if not, throw an exception
if (!tensor_ptr->is_contiguous()) {
env->ThrowNew(ENGINE_EXCEPTION_CLASS, "torchDirectByteBuffer() requires the tensor to be contiguous");
return nullptr;
}
size_t nbytes = tensor_ptr->nbytes();
// todo indeed, we can remove it in future!
if (nbytes > 0x7fffffff) {
env->ThrowNew(ENGINE_EXCEPTION_CLASS, "torchDirectByteBuffer() is not supported for large tensor");
return nullptr;
}
// Use tensor.data_ptr() to obtain the data pointer, and create a direct ByteBuffer using NewDirectByteBuffer
void* data_ptr = tensor_ptr->data_ptr();
return env->NewDirectByteBuffer(data_ptr, nbytes);
API_END_RETURN()
}

JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIsContiguous(
JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
bool is_contiguous = tensor_ptr->is_contiguous();
return static_cast<jboolean>(is_contiguous);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchToContiguous(
JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
// sparse and mkldnn are required to be converted to dense to access data ptr
auto tensor = (tensor_ptr->is_sparse() || tensor_ptr->is_mkldnn()) ? tensor_ptr->to_dense() : *tensor_ptr;
tensor = (tensor.is_contiguous()) ? tensor : tensor.contiguous();
const auto* result_ptr = new torch::Tensor(tensor);
return reinterpret_cast<jlong>(result_ptr);
API_END_RETURN()
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDeleteTensor(
JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ public String[] toStringArray(Charset charset) {

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
public ByteBuffer toByteBuffer(boolean tryDirect) {
if (getDataType() == DataType.STRING) {
throw new IllegalArgumentException("Please use toStringArray() for String NDArray.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public void detach() {

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
public ByteBuffer toByteBuffer(boolean tryDirect) {
data.rewind();
return data;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ public NDArray stopGradient() {

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
public ByteBuffer toByteBuffer(boolean tryDirect) {
ByteBuffer bb = RustLibrary.getByteBuffer(getHandle());
bb.order(ByteOrder.nativeOrder());
return bb;
Expand Down

0 comments on commit 2d9c84b

Please sign in to comment.