From f6f6d012ac081098b5cdc814b42f1bddded2b784 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Sat, 19 Feb 2022 11:16:37 +0800 Subject: [PATCH] [jvm-packages] Do not repartition when nWorker = 1 --- .../scala/rapids/spark/GpuPreXGBoost.scala | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala index b6399d58cd7b..0c3521069b37 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2021 by Contributors + Copyright (c) 2021-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -397,11 +397,22 @@ object GpuPreXGBoost extends PreXGBoostProvider { // No light cost way to get number of partitions from DataFrame, so always repartition val newDF = colData.groupColName .map(gn => repartitionForGroup(gn, colData.rawDF, nWorkers)) - .getOrElse(colData.rawDF.repartition(nWorkers)) + .getOrElse(repartitionInputData(colData.rawDF, nWorkers)) name -> ColumnDataBatch(newDF, colData.colIndices, colData.groupColName) } } + private def repartitionInputData(dataFrame: DataFrame, nWorkers: Int): DataFrame = { + // We can't check dataFrame.rdd.getNumPartitions == nWorkers here, since dataFrame.rdd is + // a lazy variable. 
If we call it here, RDD[Table] can no longer be extracted directly; + // instead the data goes through Columnar -> Row -> Columnar, which hurts performance. + if (nWorkers == 1) { + dataFrame.coalesce(1) + } else { + dataFrame.repartition(nWorkers) + } + } + private def repartitionForGroup( groupName: String, dataFrame: DataFrame, @@ -415,7 +426,7 @@ object GpuPreXGBoost extends PreXGBoostProvider { implicit val encoder = RowEncoder(schema) // Expand the grouped rows after repartition - groupedDF.repartition(nWorkers).mapPartitions(iter => { + repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => { new Iterator[Row] { var iterInRow: Iterator[Any] = Iterator.empty