From f34cc82cd358fd2bf231994b148a52ae80a495be Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Mon, 19 Nov 2018 15:23:48 -0800 Subject: [PATCH] update example --- .../scala/example/spark/SparkTraining.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala index 13d8ecb1a25d..573630a195d4 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala @@ -40,7 +40,7 @@ object SparkTraining { StructField("petal length", DoubleType, true), StructField("petal width", DoubleType, true), StructField("class", StringType, true))) - val rawInput = spark.read.schema(schema).csv(args(0)) + val rawInput = spark.read.schema(schema).csv(inputPath) // transform class to index to make xgboost happy val stringIndexer = new StringIndexer() @@ -55,6 +55,8 @@ object SparkTraining { val xgbInput = vectorAssembler.transform(labelTransformed).select("features", "classIndex") + val Array(train, eval1, eval2, test) = xgbInput.randomSplit(Array(0.6, 0.2, 0.2, 0.1)) + /** * setup "timeout_request_workers" -> 60000L to make this application if it cannot get enough resources * to get 2 workers within 60000 ms @@ -67,12 +69,14 @@ object SparkTraining { "objective" -> "multi:softprob", "num_class" -> 3, "num_round" -> 100, - "num_workers" -> 2) + "num_workers" -> 2, + "eval_sets" -> Array(eval1, eval2), + "eval_set_names" -> Array("eval1", "eval2")) val xgbClassifier = new XGBoostClassifier(xgbParam). setFeaturesCol("features"). setLabelCol("classIndex") - val xgbClassificationModel = xgbClassifier.fit(xgbInput) - val results = xgbClassificationModel.transform(xgbInput) + val xgbClassificationModel = xgbClassifier.fit(train) + val results = xgbClassificationModel.transform(test) results.show() } }