Skip to content

Commit

Permalink
Rewrite cudf column and cudf column batch. (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 authored Jun 26, 2024
1 parent 779fde8 commit efd6f42
Show file tree
Hide file tree
Showing 30 changed files with 281 additions and 465 deletions.
2 changes: 1 addition & 1 deletion jvm-packages/checkstyle.xml
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@
<property name="braceAdjustment" value="0"/>
<property name="caseIndent" value="2"/>
<property name="throwsIndent" value="4"/>
<property name="lineWrappingIndentation" value="4"/>
<property name="lineWrappingIndentation" value="2"/>
<property name="arrayInitIndent" value="2"/>
</module>
<module name="ImportOrder">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,9 @@ object SparkMLlibPipeline {
"max_depth" -> 2,
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> numWorkers,
"device" -> device
)
)
).setNumRound(10).setNumWorkers(numWorkers)
booster.setFeaturesCol("features")
booster.setLabelCol("classIndex")
val labelConverter = new IndexToString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ private[spark] def run(spark: SparkSession, inputPath: String,
"max_depth" -> 2,
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> numWorkers,
"device" -> device,
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
"device" -> device)
val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features").
setLabelCol("classIndex")
.setNumWorkers(numWorkers)
.setNumRound(10)
.setEvalDataset(eval1)
val xgbClassificationModel = xgbClassifier.fit(train)
xgbClassificationModel.transform(test)
}
Expand Down
6 changes: 6 additions & 0 deletions jvm-packages/xgboost4j-spark-gpu/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@
<version>${spark.rapids.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${fasterxml.jackson.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,93 +16,102 @@

package ml.dmlc.xgboost4j.java;

import java.util.ArrayList;
import java.util.List;

import ai.rapids.cudf.BaseDeviceMemoryBuffer;
import ai.rapids.cudf.BufferType;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

/**
* This class is composed of base data in Apache Arrow format, taken from a Cudf ColumnVector.
* It will be used to generate the cuda array interface.
* CudfColumn represents a CUDF column, providing the cuda array interface
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public class CudfColumn extends Column {
private List<Long> shape = new ArrayList<>(); // row count
private List<Object> data = new ArrayList<>(); // gpu data buffer address
private String typestr;
private int version = 1;
private CudfColumn mask = null;

public CudfColumn(long shape, long data, String typestr, int version) {
this.shape.add(shape);
this.data.add(data);
this.data.add(false);
this.typestr = typestr;
this.version = version;
}

private final long dataPtr; // gpu data buffer address
private final long shape; // row count
private final long validPtr; // gpu valid buffer address
private final int typeSize; // type size in bytes
private final String typeStr; // follow array interface spec
private final long nullCount; // null count

private String arrayInterface = null; // the cuda array interface

/**
* Create CudfColumn according to ColumnVector
*/
public static CudfColumn from(ColumnVector cv) {
BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA);
BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY);
long validPtr = 0;
if (validBuffer != null) {
validPtr = validBuffer.getAddress();
}
BaseDeviceMemoryBuffer dataBuffer = cv.getData();
assert dataBuffer != null;

DType dType = cv.getType();
String typeStr = "";
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
dType == DType.TIMESTAMP_SECONDS) {
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
dType == DType.TIMESTAMP_SECONDS) {
typeStr = "<f" + dType.getSizeInBytes();
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
dType == DType.INT32 || dType == DType.INT64) {
dType == DType.INT32 || dType == DType.INT64) {
typeStr = "<i" + dType.getSizeInBytes();
} else {
// Unsupported type.
throw new IllegalArgumentException("Unsupported data type: " + dType);
}

return new CudfColumn(dataBuffer.getAddress(), cv.getRowCount(), validPtr,
dType.getSizeInBytes(), typeStr, cv.getNullCount());
}
CudfColumn data = new CudfColumn(cv.getRowCount(), dataBuffer.getAddress(), typeStr, 1);

private CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr,
long nullCount) {
this.dataPtr = dataPtr;
this.shape = shape;
this.validPtr = validPtr;
this.typeSize = typeSize;
this.typeStr = typeStr;
this.nullCount = nullCount;
BaseDeviceMemoryBuffer validBuffer = cv.getValid();
if (validBuffer != null && cv.getNullCount() != 0) {
CudfColumn mask = new CudfColumn(cv.getRowCount(), validBuffer.getAddress(), "<t1", 1);
data.setMask(mask);
}
return data;
}

@Override
public String getArrayInterfaceJson() {
// There is no race-condition
if (arrayInterface == null) {
arrayInterface = CudfUtils.buildArrayInterface(this);
}
return arrayInterface;
public List<Long> getShape() {
return shape;
}

public long getDataPtr() {
return dataPtr;
public List<Object> getData() {
return data;
}

public long getShape() {
return shape;
public String getTypestr() {
return typestr;
}

public long getValidPtr() {
return validPtr;
public int getVersion() {
return version;
}

public int getTypeSize() {
return typeSize;
public CudfColumn getMask() {
return mask;
}

public String getTypeStr() {
return typeStr;
public void setMask(CudfColumn mask) {
this.mask = mask;
}

public long getNullCount() {
return nullCount;
@Override
public String toJson() {
ObjectMapper mapper = new ObjectMapper();
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
try {
List<CudfColumn> objects = new ArrayList<>(1);
objects.add(this);
return mapper.writeValueAsString(objects);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,71 +16,108 @@

package ml.dmlc.xgboost4j.java;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import ai.rapids.cudf.Table;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

/**
* Class to wrap CUDF Table to generate the cuda array interface.
* CudfColumnBatch wraps multiple CudfColumns to provide the cuda
* array interface JSON string for all columns.
*/
public class CudfColumnBatch extends ColumnBatch {
private final Table feature;
private final Table label;
private final Table weight;
private final Table baseMargin;

public CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins) {
this.feature = feature;
this.label = labels;
this.weight = weights;
this.baseMargin = baseMargins;
@JsonIgnore
private final Table featureTable;
@JsonIgnore
private final Table labelTable;
@JsonIgnore
private final Table weightTable;
@JsonIgnore
private final Table baseMarginTable;

private List<CudfColumn> features;
private List<CudfColumn> label;
private List<CudfColumn> weight;
private List<CudfColumn> baseMargin;

public CudfColumnBatch(Table featureTable, Table labelTable, Table weightTable,
Table baseMarginTable) {
this.featureTable = featureTable;
this.labelTable = labelTable;
this.weightTable = weightTable;
this.baseMarginTable = baseMarginTable;

features = initializeCudfColumns(featureTable);
if (labelTable != null) {
assert labelTable.getNumberOfColumns() == 1;
label = initializeCudfColumns(labelTable);
}

if (weightTable != null) {
assert weightTable.getNumberOfColumns() == 1;
weight = initializeCudfColumns(weightTable);
}

if (baseMarginTable != null) {
baseMargin = initializeCudfColumns(baseMarginTable);
}
}

@Override
public String getFeatureArrayInterface() {
return getArrayInterface(this.feature);
private List<CudfColumn> initializeCudfColumns(Table table) {
assert table != null && table.getNumberOfColumns() > 0;

return IntStream.range(0, table.getNumberOfColumns())
.mapToObj(table::getColumn)
.map(CudfColumn::from)
.collect(Collectors.toList());
}

@Override
public String getLabelsArrayInterface() {
return getArrayInterface(this.label);
public List<CudfColumn> getFeatures() {
return features;
}

@Override
public String getWeightsArrayInterface() {
return getArrayInterface(this.weight);
public List<CudfColumn> getLabel() {
return label;
}

@Override
public String getBaseMarginsArrayInterface() {
return getArrayInterface(this.baseMargin);
public List<CudfColumn> getWeight() {
return weight;
}

@Override
public void close() {
if (feature != null) feature.close();
if (label != null) label.close();
if (weight != null) weight.close();
if (baseMargin != null) baseMargin.close();
public List<CudfColumn> getBaseMargin() {
return baseMargin;
}

private String getArrayInterface(Table table) {
if (table == null || table.getNumberOfColumns() == 0) {
return "";
public String toJson() {
ObjectMapper mapper = new ObjectMapper();
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
try {
return mapper.writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
return CudfUtils.buildArrayInterface(getAsCudfColumn(table));
}

private CudfColumn[] getAsCudfColumn(Table table) {
if (table == null || table.getNumberOfColumns() == 0) {
// This will never happen.
return new CudfColumn[]{};
@Override
public String toFeaturesJson() {
ObjectMapper mapper = new ObjectMapper();
try {
return mapper.writeValueAsString(features);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}

return IntStream.range(0, table.getNumberOfColumns())
.mapToObj((i) -> table.getColumn(i))
.map(CudfColumn::from)
.toArray(CudfColumn[]::new);
}

@Override
public void close() {
if (featureTable != null) featureTable.close();
if (labelTable != null) labelTable.close();
if (weightTable != null) weightTable.close();
if (baseMarginTable != null) baseMarginTable.close();
}
}
Loading

0 comments on commit efd6f42

Please sign in to comment.