-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[jvm-packages] delete all constraints from spark layer about obj and …
…eval metrics and handle error in jvm layer (#4560) * temp * prediction part * remove supported* * add for test * fix param name * add rabit * update rabit * return value of rabit init * eliminate compilation warnings * update rabit * shutdown * update rabit again * check sparkcontext shutdown * fix logic * sleep * fix tests * test with relaxed threshold * create new thread each time * stop for job quitting * udpate rabit * update rabit * update rabit * update git modules
- Loading branch information
Showing
11 changed files
with
142 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
81 changes: 81 additions & 0 deletions
81
...ackages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
/* | ||
Copyright (c) 2014 by Contributors | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package ml.dmlc.xgboost4j.scala.spark | ||
|
||
import ml.dmlc.xgboost4j.java.XGBoostError | ||
import org.scalatest.{BeforeAndAfterAll, FunSuite} | ||
|
||
import org.apache.spark.ml.param.ParamMap | ||
|
||
class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll { | ||
|
||
test("XGBoost and Spark parameters synchronize correctly") { | ||
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic", | ||
"objective_type" -> "classification") | ||
// from xgboost params to spark params | ||
val xgb = new XGBoostClassifier(xgbParamMap) | ||
assert(xgb.getEta === 1.0) | ||
assert(xgb.getObjective === "binary:logistic") | ||
assert(xgb.getObjectiveType === "classification") | ||
// from spark to xgboost params | ||
val xgbCopy = xgb.copy(ParamMap.empty) | ||
assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0) | ||
assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic") | ||
assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification") | ||
val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss")) | ||
assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss") | ||
} | ||
|
||
private def waitForSparkContextShutdown(): Unit = { | ||
var totalWaitedTime = 0L | ||
while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) { | ||
Thread.sleep(10000) | ||
totalWaitedTime += 10000 | ||
} | ||
assert(ss.sparkContext.isStopped === true) | ||
} | ||
|
||
test("fail training elegantly with unsupported objective function") { | ||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", | ||
"objective" -> "wrong_objective_function", "num_class" -> "6", "num_round" -> 5, | ||
"num_workers" -> numWorkers) | ||
val trainingDF = buildDataFrame(MultiClassification.train) | ||
val xgb = new XGBoostClassifier(paramMap) | ||
try { | ||
val model = xgb.fit(trainingDF) | ||
} catch { | ||
case e: Throwable => // swallow anything | ||
} finally { | ||
waitForSparkContextShutdown() | ||
} | ||
} | ||
|
||
test("fail training elegantly with unsupported eval metrics") { | ||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", | ||
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5, | ||
"num_workers" -> numWorkers, "eval_metric" -> "wrong_eval_metrics") | ||
val trainingDF = buildDataFrame(MultiClassification.train) | ||
val xgb = new XGBoostClassifier(paramMap) | ||
try { | ||
val model = xgb.fit(trainingDF) | ||
} catch { | ||
case e: Throwable => // swallow anything | ||
} finally { | ||
waitForSparkContextShutdown() | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule rabit
updated
13 files
+4 −2 | include/rabit/c_api.h | |
+2 −2 | include/rabit/internal/engine.h | |
+4 −4 | include/rabit/internal/rabit-inl.h | |
+19 −4 | include/rabit/internal/utils.h | |
+4 −2 | include/rabit/rabit.h | |
+184 −151 | src/allreduce_base.cc | |
+3 −3 | src/allreduce_base.h | |
+27 −17 | src/allreduce_robust.cc | |
+2 −2 | src/allreduce_robust.h | |
+4 −4 | src/c_api.cc | |
+12 −6 | src/engine.cc | |
+9 −2 | src/engine_empty.cc | |
+22 −5 | src/engine_mpi.cc |