Skip to content

Commit

Permalink
[jvm-packages] update spark dependency to 3.0.0 (#5836)
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 authored Jul 13, 2020
1 parent 23e2c6e commit 9f85e92
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 10 deletions.
2 changes: 1 addition & 1 deletion jvm-packages/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<flink.version>1.7.2</flink.version>
<spark.version>2.4.3</spark.version>
<spark.version>3.0.0</spark.version>
<scala.version>2.12.8</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<hadoop.version>2.7.3</hadoop.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ class XGBoostClassificationModel private[ml](
}

// Actually we don't use this function at all, to make it pass compiler check.
override protected def predictRaw(features: Vector): Vector = {
override def predictRaw(features: Vector): Vector = {
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,8 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("col1", "col2", "col3"))
.setOutputCol("features")
org.apache.spark.SPARK_VERSION match {
case version if version.startsWith("2.4") =>
val m = vectorAssembler.getClass.getDeclaredMethods
.filter(_.getName.contains("setHandleInvalid")).head
m.invoke(vectorAssembler, "keep")
case _ =>
}
.setHandleInvalid("keep")

val inputDF = vectorAssembler.transform(testDF).select("features", "label")
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)

val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f})
val trainingDF = buildDataFrame(Regression.train)
.withColumn("weight", getWeightFromId(col("id")))
val testDF = buildDataFrame(Regression.test)
Expand Down

0 comments on commit 9f85e92

Please sign in to comment.