Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 19, 2024
1 parent 541ac79 commit 14b9305
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ml.dmlc.xgboost4j.scala.spark.GPUXGBoostPlugin
ml.dmlc.xgboost4j.scala.spark.GpuXGBoostPlugin
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,20 @@ import scala.jdk.CollectionConverters.seqAsJavaListConverter

import ai.rapids.cudf.Table
import com.nvidia.spark.rapids.ColumnarRdd
import org.apache.spark.ml.param.Param
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, Dataset}
import org.apache.spark.sql.functions.col

import ml.dmlc.xgboost4j.java.{CudfColumnBatch, GpuColumnBatch}
import ml.dmlc.xgboost4j.scala.QuantileDMatrix
import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol

class GPUXGBoostPlugin extends XGBoostPlugin {
/**
 * GpuXGBoostPlugin is the XGBoost plugin that leverages spark-rapids
 * to accelerate XGBoost from ETL through training.
 */
class GpuXGBoostPlugin extends XGBoostPlugin {

/**
* Whether the plugin is enabled or not, if not enabled, fallback
Expand All @@ -48,17 +55,29 @@ class GPUXGBoostPlugin extends XGBoostPlugin {
/**
 * Preprocess the dataset into the shape the GPU training path expects:
 * select (and cast to float where needed) the label/weight/base-margin/group
 * columns plus every feature column, then repartition if required.
 *
 * NOTE(review): the pasted diff interleaved deleted lines (a stray
 * `val features = ...` and an unclosed `foreach { name =>`) with the added
 * ones; they are removed here so the method is coherent.
 *
 * @param estimator the XGBoost estimator carrying the column params
 * @param dataset   the input dataset to be transformed
 * @return the dataset restricted to the selected columns, repartitioned if needed
 */
private def preprocess[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
    estimator: XGBoostEstimator[T, M], dataset: Dataset[_]): Dataset[_] = {

  // Columns to be selected for XGBoost training
  val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty
  val schema = dataset.schema

  // Append the column named by `c` (cast to float if needed) when the param
  // is defined with a non-empty value; silently skip it otherwise.
  def selectCol(c: Param[String]): Unit = {
    // TODO support numeric types
    if (estimator.isDefinedNonEmpty(c)) {
      selectedCols.append(estimator.castToFloatIfNeeded(schema, estimator.getOrDefault(c)))
    }
  }

  Seq(estimator.labelCol, estimator.weightCol, estimator.baseMarginCol).foreach(selectCol)
  // Group column only exists on ranking estimators.
  estimator match {
    case p: HasGroupCol => selectCol(p.groupCol)
    case _ =>
  }

  // TODO support array/vector feature
  estimator.getFeaturesCols.foreach { name =>
    // Use the `schema` captured above for consistency with selectCol.
    selectedCols.append(estimator.castToFloatIfNeeded(schema, name))
  }

  val input = dataset.select(selectedCols.toSeq: _*)

  estimator.repartitionIfNeeded(input)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ private[spark] abstract class XGBoostEstimator[

val serviceLoader = ServiceLoader.load(classOf[XGBoostPlugin], classLoader)

// For now, we only trust GPUXGBoostPlugin.
// For now, we only trust GpuXGBoostPlugin.
serviceLoader.asScala.filter(x => x.getClass.getName.equals(
"ml.dmlc.xgboost4j.scala.spark.GPUXGBoostPlugin")).toList match {
"ml.dmlc.xgboost4j.scala.spark.GpuXGBoostPlugin")).toList match {
case Nil => None
case head :: Nil =>
Some(head)
Expand Down Expand Up @@ -128,7 +128,6 @@ private[spark] abstract class XGBoostEstimator[
* Build the columns indices.
*/
private[spark] def buildColumnIndices(schema: StructType): ColumnIndices = {

// Get feature id(s)
val (featureIds: Option[Seq[Int]], featureId: Option[Int]) =
if (getFeaturesCols.length != 0) {
Expand Down Expand Up @@ -161,6 +160,10 @@ private[spark] abstract class XGBoostEstimator[
groupId)
}

/**
 * Returns true when the given string param has a value (set or defaulted,
 * per Spark ML `Params.isDefined`) and that value is non-empty.
 *
 * @param param the string param to check
 */
private[spark] def isDefinedNonEmpty(param: Param[String]): Boolean = {
  // The boolean expression is the result; `if (cond) true else false` is redundant.
  isDefined(param) && $(param).nonEmpty
}

/**
* Preprocess the dataset to meet the xgboost input requirement
*
Expand All @@ -174,7 +177,7 @@ private[spark] abstract class XGBoostEstimator[
val schema = dataset.schema

def selectCol(c: Param[String]) = {
if (isDefined(c) && $(c).nonEmpty) {
if (isDefinedNonEmpty(c)) {
// Validation col should be a boolean column.
if (c == featuresCol) {
selectedCols.append(col($(c)))
Expand Down

0 comments on commit 14b9305

Please sign in to comment.