diff --git a/doc/jvm/xgboost4j_spark_tutorial.rst b/doc/jvm/xgboost4j_spark_tutorial.rst index a7332f253630..fd106be7c3d1 100644 --- a/doc/jvm/xgboost4j_spark_tutorial.rst +++ b/doc/jvm/xgboost4j_spark_tutorial.rst @@ -61,9 +61,9 @@ and then refer to the snapshot dependency by adding: next_version_num-SNAPSHOT -.. note:: XGBoost4J-Spark requires Apache Spark 2.3+ +.. note:: XGBoost4J-Spark requires Apache Spark 2.4+ - XGBoost4J-Spark now requires **Apache Spark 2.3+**. Latest versions of XGBoost4J-Spark uses facilities of `org.apache.spark.ml.param.shared` extensively to provide for a tight integration with Spark MLLIB framework, and these facilities are not fully available on earlier versions of Spark. + XGBoost4J-Spark now requires **Apache Spark 2.4+**. Latest versions of XGBoost4J-Spark use facilities of `org.apache.spark.ml.param.shared` extensively to provide for a tight integration with Spark MLLIB framework, and these facilities are not fully available on earlier versions of Spark. Also, make sure to install Spark directly from `Apache website <https://spark.apache.org/>`_. **Upstream XGBoost is not guaranteed to work with third-party distributions of Spark, such as Cloudera Spark.** Consult appropriate third parties to obtain their distribution of XGBoost. 
diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 1476801c6db3..d25cf51b1deb 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -34,7 +34,7 @@ 1.7 1.7 1.5.0 - 2.3.3 + 2.4.1 2.11.12 2.11 diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala index b79d5b694345..bb75bb342cb1 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DefaultXGBoostParamsReader.scala @@ -22,12 +22,17 @@ import org.json4s.JsonAST.JObject import org.json4s.jackson.JsonMethods.{compact, parse, render} import org.apache.spark.SparkContext -import org.apache.spark.ml.param.Params +import org.apache.spark.ml.param.{Param, Params} import org.apache.spark.ml.util.MLReader // This originates from apache-spark DefaultPramsReader copy paste private[spark] object DefaultXGBoostParamsReader { + private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity") + + private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] = + Map("objective" -> Map("reg:linear" -> "reg:squarederror")) + /** * All info from metadata file. * @@ -103,6 +108,14 @@ private[spark] object DefaultXGBoostParamsReader { Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) } + private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = { + paramValueCompatibilityMap.getOrElse(paramName, Map()).getOrElse(value, value).asInstanceOf[T] + } + + private def handleBrokenlyChangedName(paramName: String): String = { + paramNameCompatibilityMap.getOrElse(paramName, paramName) + } + /** * Extract Params from metadata, and set them in the instance. 
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. @@ -113,9 +126,9 @@ private[spark] object DefaultXGBoostParamsReader { metadata.params match { case JObject(pairs) => pairs.foreach { case (paramName, jsonValue) => - val param = instance.getParam(paramName) + val param = instance.getParam(handleBrokenlyChangedName(paramName)) val value = param.jsonDecode(compact(render(jsonValue))) - instance.set(param, value) + instance.set(param, handleBrokenlyChangedValue(paramName, value)) } case _ => throw new IllegalArgumentException( diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/data/XGBoostClassificationModel b/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/data/XGBoostClassificationModel new file mode 100644 index 000000000000..5d915d02f5f8 Binary files /dev/null and b/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/data/XGBoostClassificationModel differ diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/metadata/_SUCCESS b/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/metadata/_SUCCESS new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/metadata/part-00000 b/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/metadata/part-00000 new file mode 100644 index 000000000000..7e1a7221ace3 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/metadata/part-00000 @@ -0,0 +1 @@ 
+{"class":"ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel","timestamp":1555350539033,"sparkVersion":"2.3.2-uber-109","uid":"xgbc_5e7bec215a4c","paramMap":{"useExternalMemory":false,"trainTestRatio":1.0,"alpha":0.0,"seed":0,"numWorkers":100,"skipDrop":0.0,"treeLimit":0,"silent":0,"trackerConf":{"workerConnectionTimeout":0,"trackerImpl":"python"},"missing":"NaN","colsampleBylevel":1.0,"probabilityCol":"probability","checkpointPath":"","lambda":1.0,"rawPredictionCol":"rawPrediction","eta":0.3,"numEarlyStoppingRounds":0,"growPolicy":"depthwise","gamma":0.0,"sampleType":"uniform","maxDepth":6,"rateDrop":0.0,"objective":"reg:linear","customObj":null,"lambdaBias":0.0,"baseScore":0.5,"labelCol":"label","minChildWeight":1.0,"customEval":null,"normalizeType":"tree","maxBin":16,"nthread":4,"numRound":20,"colsampleBytree":1.0,"predictionCol":"prediction","subsample":1.0,"timeoutRequestWorkers":1800000,"featuresCol":"features","evalMetric":"error","sketchEps":0.03,"scalePosWeight":1.0,"checkpointInterval":-1,"maxDeltaStep":0.0,"treeMethod":"approx"}} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala index 4c0b21073f2f..220cea307ff4 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala @@ -19,9 +19,11 @@ package ml.dmlc.xgboost4j.scala.spark import java.io.{File, FileNotFoundException} import java.util.Arrays -import ml.dmlc.xgboost4j.scala.DMatrix +import scala.io.Source +import ml.dmlc.xgboost4j.scala.DMatrix import scala.util.Random + import org.apache.spark.ml.feature._ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.network.util.JavaUtils @@ -162,5 +164,17 @@ class PersistenceSuite extends FunSuite with PerTest with 
BeforeAndAfterAll { assert(xgbModel.getNumRound === xgbModel2.getNumRound) assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol) } + + test("cross-version model loading (0.82)") { + val modelPath = getClass.getResource("/model/0.82/model").getPath + val model = XGBoostClassificationModel.read.load(modelPath) + val r = new Random(0) + val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))). + toDF("feature", "label") + val assembler = new VectorAssembler() + .setInputCols(df.columns.filter(!_.contains("label"))) + .setOutputCol("features") + model.transform(assembler.transform(df)).show() + } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index 50f827c34bc8..1affe1474f2a 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -261,6 +261,7 @@ class XGBoostGeneralSuite extends FunSuite with PerTest { val vectorAssembler = new VectorAssembler() .setInputCols(Array("col1", "col2", "col3")) .setOutputCol("features") + .setHandleInvalid("keep") val inputDF = vectorAssembler.transform(testDF).select("features", "label") val paramMap = List("eta" -> "1", "max_depth" -> "2", "objective" -> "binary:logistic", "num_workers" -> 1).toMap