Skip to content

Commit

Permalink
supporting array
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 25, 2024
1 parent 779fde8 commit ff60e14
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@

package ml.dmlc.xgboost4j.java;

import ai.rapids.cudf.BaseDeviceMemoryBuffer;
import ai.rapids.cudf.BufferType;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.*;

/**
* This class is composing of base data with Apache Arrow format from Cudf ColumnVector.
Expand All @@ -30,46 +27,63 @@ public class CudfColumn extends Column {
private final long dataPtr; // gpu data buffer address
private final long shape; // row count
private final long validPtr; // gpu valid buffer address
private final long offsetPtr; // gpu offset address for list
private final int typeSize; // type size in bytes
private final String typeStr; // follow array interface spec
private final long nullCount; // null count

private String arrayInterface = null; // the cuda array interface

public static CudfColumn from(ColumnVector cv) {
BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA);
BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY);
public static CudfColumn from(ColumnVector column) {
ColumnView cv = column;
DType dType = cv.getType();
long nullCount = cv.getNullCount();
String floatType = "<f";
String integerType = "<i";
long offsetPtr = 0;
if (dType == DType.LIST) {
floatType = "<lf";
integerType = "<li";
if (cv.getOffsets() != null) {
offsetPtr = cv.getOffsets().getAddress();
}
cv = cv.getChildColumnView(0);
dType = cv.getType();
}

BaseDeviceMemoryBuffer dataBuffer = cv.getData();
BaseDeviceMemoryBuffer validBuffer = cv.getValid();
long validPtr = 0;
if (validBuffer != null) {
validPtr = validBuffer.getAddress();
}
DType dType = cv.getType();
String typeStr = "";
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
dType == DType.TIMESTAMP_SECONDS) {
typeStr = "<f" + dType.getSizeInBytes();
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
dType == DType.TIMESTAMP_SECONDS) {
typeStr = floatType + dType.getSizeInBytes();
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
dType == DType.INT32 || dType == DType.INT64) {
typeStr = "<i" + dType.getSizeInBytes();
dType == DType.INT32 || dType == DType.INT64) {
typeStr = integerType + dType.getSizeInBytes();
} else {
// Unsupported type.
throw new IllegalArgumentException("Unsupported data type: " + dType);
}

return new CudfColumn(dataBuffer.getAddress(), cv.getRowCount(), validPtr,
dType.getSizeInBytes(), typeStr, cv.getNullCount());
dType.getSizeInBytes(), typeStr, nullCount, offsetPtr);
}

private CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr,
long nullCount) {
long nullCount, long offsetPtr) {
this.dataPtr = dataPtr;
this.shape = shape;
this.validPtr = validPtr;
this.typeSize = typeSize;
this.typeStr = typeStr;
this.nullCount = nullCount;
this.offsetPtr = offsetPtr;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package ml.dmlc.xgboost4j.java;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.ColumnView;
import ai.rapids.cudf.HostColumnVector;
import ai.rapids.cudf.Table;
import junit.framework.TestCase;
import org.junit.Test;
Expand Down Expand Up @@ -135,4 +138,56 @@ private float[] convertFloatTofloat(Float[]... datas) {
}
return floatArray;
}

@Test
public void testMakingDMatrixViaArray() {
// ColumnVector child1 = ColumnVector.fromFloats(1, 2, 3, 4, 5, 6);
// ColumnVector child2 = ColumnVector.fromFloats(11, 12, 13, 14, 15, 16);
// ColumnVector list = ColumnVector.makeList(child1, child2);
// child2.close();
// child1.close();

Float[][] features1 = {
{1.0f, 12.0f},
{2.0f, 13.0f},
null,
{4.0f, null},
{5.0f, 16.0f}
};

Float[] label1 = {0.0f, 1.0f, 0.0f, 1.0f, 0.0f};

Table X1 = new Table.TestBuilder().column(features1).build();
Table y1 = new Table.TestBuilder().column(label1).build();

// HostColumnVector hcv = X1.getColumn(0).copyToHost();

//
ColumnVector t = X1.getColumn(0);
ColumnView cv = t.getChildColumnView(0);
//
System.out.println("----");

Float[][] features2 = {
{6.0f, 17.0f},
{7.0f, 18.0f},
};
Float[] label2 = {0.0f, 1.0f, 0.0f, 1.0f, 0.0f};
Table X2 = new Table.TestBuilder().column(features2).build();
Table y2 = new Table.TestBuilder().column(label2).build();

List<ColumnBatch> tables = new LinkedList<>();
tables.add(new CudfColumnBatch(X1, y1, null, null));
tables.add(new CudfColumnBatch(X2, y2, null, null));

try {
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
} catch (XGBoostError e) {
throw new RuntimeException(e);
}

System.out.println("--------------");


}
}

0 comments on commit ff60e14

Please sign in to comment.