diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala index 3b7fe4a65fea..dc2a52c4dd81 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala @@ -89,21 +89,22 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest { val workerCount: Int = numWorkers - rdd.mapPartitions { iter => + val dummyTasks = rdd.mapPartitions { iter => val index = iter.next() Communicator.init(trackerEnvs) - val a = Array(1.0f, 2.0f, 3.0f) - val b = Communicator.allReduce(a, Communicator.OpType.SUM) - for (i <- 0 to 2) { - assert(a(i) * workerCount == b(i)) - } - val c = Communicator.allReduce(a, Communicator.OpType.MIN); - for (i <- 0 to 2) { - assert(a(i) == c(i)) - } Communicator.shutdown() Iterator(index) - }.collect() + }.cache() + + val sparkThread = new Thread() { + override def run(): Unit = { + // forces a Spark job. + dummyTasks.foreachPartition(() => _) + } + } + + sparkThread.setUncaughtExceptionHandler(tracker) + sparkThread.start() } test("should allow the dataframe containing communicator calls to be partially evaluated for" +