[jvm-packages] fix spark-rapids compatibility issue (#8240)
spark-rapids (from 22.10) has shimmed GpuColumnVector, which means we can no longer call it directly. This PR instead resolves the unshimmed GpuColumnVectorUtils class via reflection and calls its extractHostColumns method, falling back to GpuColumnVector when running against older spark-rapids releases.
wbo4958 authored Sep 22, 2022
1 parent ab342af commit 8d247f0
Showing 4 changed files with 20 additions and 7 deletions.
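
The fix boils down to one pattern: resolve the unshimmed helper class by name at runtime and fall back to the old API when that class is absent, instead of referencing the (now shimmed) GpuColumnVector directly. A minimal, illustrative sketch of that pattern (the helper below is hypothetical and not part of the commit):

// Illustrative only: try a class that newer spark-rapids releases provide,
// and fall back when an older release without it is on the classpath.
def viaUnshimmedOrFallback[T](className: String)(use: Class[_] => T)(fallback: => T): T =
  try {
    use(Class.forName(className))
  } catch {
    case _: ClassNotFoundException => fallback // class not shipped by this spark-rapids version
  }

The GpuUtils change below follows this shape, with GpuColumnVectorUtils.extractHostColumns as the new path and GpuColumnVector.extractColumns(...).map(_.copyToHost()) as the fallback; it goes through Utils.classForName rather than Class.forName, presumably so the lookup uses the Spark class loader.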
@@ -16,10 +16,8 @@

package ml.dmlc.xgboost4j.scala.rapids.spark

import scala.collection.Iterator
import scala.collection.JavaConverters._

import com.nvidia.spark.rapids.{GpuColumnVector}
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, DeviceQuantileDMatrix}
@@ -331,7 +329,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
} else {
try {
currentBatch = new ColumnarBatch(
GpuColumnVector.extractColumns(table, dataTypes).map(_.copyToHost()),
GpuUtils.extractBatchToHost(table, dataTypes),
table.getRowCount().toInt)
val rowIterator = currentBatch.rowIterator().asScala
.map(toUnsafe)
@@ -17,16 +17,32 @@
package ml.dmlc.xgboost4j.scala.rapids.spark

import ai.rapids.cudf.Table
import com.nvidia.spark.rapids.ColumnarRdd
import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVector}
import ml.dmlc.xgboost4j.scala.spark.util.Utils

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{FloatType, NumericType, StructType}
import org.apache.spark.sql.types.{DataType, FloatType, NumericType, StructType}
import org.apache.spark.sql.vectorized.ColumnVector

private[spark] object GpuUtils {

def extractBatchToHost(table: Table, types: Array[DataType]): Array[ColumnVector] = {
// spark-rapids has shimmed the GpuColumnVector from 22.10
try {
val clazz = Utils.classForName("com.nvidia.spark.rapids.GpuColumnVectorUtils")
clazz.getDeclaredMethod("extractHostColumns", classOf[Table], classOf[Array[DataType]])
.invoke(null, table, types).asInstanceOf[Array[ColumnVector]]
} catch {
case _: ClassNotFoundException =>
// If it's older version, use the GpuColumnVector
GpuColumnVector.extractColumns(table, types).map(_.copyToHost())
case e: Throwable => throw e
}
}

def toColumnarRdd(df: DataFrame): RDD[Table] = ColumnarRdd(df)

def seqIntToSeqInteger(x: Seq[Int]): Seq[Integer] = x.map(new Integer(_))
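For context, this is how the new helper is consumed by the call site shown in the first hunk: copy every GPU column of a cuDF Table to the host, wrap the copies in a ColumnarBatch, and iterate the batch row by row. A condensed sketch, assuming GpuUtils from this file and a table with its Spark column types (the actual code additionally maps each row through an UnsafeProjection, i.e. toUnsafe):

import scala.collection.JavaConverters._
import ai.rapids.cudf.Table
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnarBatch

def rowsOnHost(table: Table, dataTypes: Array[DataType]): Iterator[InternalRow] = {
  // extractBatchToHost returns host-side copies of the GPU columns, so the
  // resulting ColumnarBatch can be read without touching device memory.
  val batch = new ColumnarBatch(
    GpuUtils.extractBatchToHost(table, dataTypes),
    table.getRowCount().toInt)
  batch.rowIterator().asScala
}
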
@@ -20,7 +20,6 @@ import java.nio.file.{Files, Path}
import java.sql.{Date, Timestamp}
import java.util.{Locale, TimeZone}

import com.nvidia.spark.rapids.RapidsConf
import org.scalatest.{BeforeAndAfterAll, FunSuite}

import org.apache.spark.SparkConf
@@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark.util
import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}

// based on org.apache.spark.util copy /paste
private[spark] object Utils {
object Utils {

def getSparkClassLoader: ClassLoader = getClass.getClassLoader

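The Utils change above drops the private[spark] modifier. A plausible reason, inferred from the GpuUtils hunk rather than stated in the commit: Utils is declared in ml.dmlc.xgboost4j.scala.spark.util, so private[spark] limits access to the ml.dmlc.xgboost4j.scala.spark package tree, while the new caller GpuUtils lives in ml.dmlc.xgboost4j.scala.rapids.spark, which is outside that tree. A minimal sketch of the cross-package use that now has to compile (the object name is hypothetical; Utils.classForName is assumed to return a Class[_], matching its use above):

package ml.dmlc.xgboost4j.scala.rapids.spark

import ml.dmlc.xgboost4j.scala.spark.util.Utils

object ShimLookupSketch {
  // This reference only compiles because Utils is now public; with the old
  // private[spark] modifier it would be invisible from the rapids.spark package.
  def unshimmedHelper(): Class[_] =
    Utils.classForName("com.nvidia.spark.rapids.GpuColumnVectorUtils")
}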
