diff --git a/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala b/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala index 9e4522cd57..46bab8145a 100644 --- a/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala +++ b/core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala @@ -59,86 +59,114 @@ object ReflectionUtil { // isOrderSensitive: Boolean = false): RDD[U] // // Hereby we use reflection to support different Spark versions. - private val mapPartitionsWithIndexInternal: Method = TiSparkInfo.SPARK_VERSION match { - case "2.3.0" | "2.3.1" => - tryLoadMethod( - "mapPartitionsWithIndexInternal", - mapPartitionsWithIndexInternalV1, - mapPartitionsWithIndexInternalV2 - ) - case _ => - // Spark version >= 2.3.2 - tryLoadMethod( - "mapPartitionsWithIndexInternal", - mapPartitionsWithIndexInternalV2, - mapPartitionsWithIndexInternalV1 - ) - } + case class ReflectionMapPartitionWithIndexInternal( + rdd: RDD[InternalRow], + internalRowToUnsafeRowWithIndex: (Int, Iterator[InternalRow]) => Iterator[UnsafeRow] + ) { + // Spark HDP Release may not compatible with official Release + // see https://github.com/pingcap/tispark/issues/1006 + def invoke(): RDD[InternalRow] = { + val (version, method) = TiSparkInfo.SPARK_VERSION match { + case "2.3.0" | "2.3.1" => + try { + reflectMapPartitionsWithIndexInternalV1(rdd, internalRowToUnsafeRowWithIndex) + } catch { + case _: Throwable => + try { + reflectMapPartitionsWithIndexInternalV2(rdd, internalRowToUnsafeRowWithIndex) + } catch { + case _: Throwable => + throw ScalaReflectionException( + s"Cannot find reflection of Method mapPartitionsWithIndexInternal, current Spark version is %s" + .format(TiSparkInfo.SPARK_VERSION) + ) + } + } + + case _ => + try { + reflectMapPartitionsWithIndexInternalV2(rdd, internalRowToUnsafeRowWithIndex) + } catch { + case _: Throwable => + try { + reflectMapPartitionsWithIndexInternalV1(rdd, internalRowToUnsafeRowWithIndex) + } catch { + case _: Throwable => + throw ScalaReflectionException( + s"Cannot find reflection of Method mapPartitionsWithIndexInternal, current Spark version is %s" + .format(TiSparkInfo.SPARK_VERSION) + ) + } + } + } - // Spark HDP Release may not compatible with official Release - // see https://github.com/pingcap/tispark/issues/1006 - private def tryLoadMethod(name: String, f1: () => Method, f2: () => Method): Method = { - try { - f1.apply() - } catch { - case _: Throwable => - try { - f2.apply() - } catch { - case _: Throwable => - throw ScalaReflectionException( - s"Cannot find reflection of Method $name, current Spark version is %s" - .format(TiSparkInfo.SPARK_VERSION) - ) - } + invokeMapPartitionsWithIndexInternal(version, method, rdd, internalRowToUnsafeRowWithIndex) } } // Spark-2.3.0 & Spark-2.3.1 - private def mapPartitionsWithIndexInternalV1(): Method = - classOf[RDD[InternalRow]].getDeclaredMethod( - "mapPartitionsWithIndexInternal", - classOf[(Int, Iterator[InternalRow]) => Iterator[UnsafeRow]], - classOf[Boolean], - classOf[ClassTag[UnsafeRow]] + private def reflectMapPartitionsWithIndexInternalV1( + rdd: RDD[InternalRow], + internalRowToUnsafeRowWithIndex: (Int, Iterator[InternalRow]) => Iterator[UnsafeRow] + ): (String, Method) = { + ( + "v1", + classOf[RDD[InternalRow]].getDeclaredMethod( + "mapPartitionsWithIndexInternal", + classOf[(Int, Iterator[InternalRow]) => Iterator[UnsafeRow]], + classOf[Boolean], + classOf[ClassTag[UnsafeRow]] + ) ) + } // >= Spark-2.3.2 - private def mapPartitionsWithIndexInternalV2(): Method = - classOf[RDD[InternalRow]].getDeclaredMethod( - "mapPartitionsWithIndexInternal", - classOf[(Int, Iterator[InternalRow]) => Iterator[UnsafeRow]], - classOf[Boolean], - classOf[Boolean], - classOf[ClassTag[UnsafeRow]] + private def reflectMapPartitionsWithIndexInternalV2( + rdd: RDD[InternalRow], + internalRowToUnsafeRowWithIndex: (Int, Iterator[InternalRow]) => Iterator[UnsafeRow] + ): (String, Method) = { + ( + "v2", + classOf[RDD[InternalRow]].getDeclaredMethod( + "mapPartitionsWithIndexInternal", + classOf[(Int, Iterator[InternalRow]) => Iterator[UnsafeRow]], + classOf[Boolean], + classOf[Boolean], + classOf[ClassTag[UnsafeRow]] + ) ) + } - case class ReflectionMapPartitionWithIndexInternal( + private def invokeMapPartitionsWithIndexInternal( + version: String, + method: Method, rdd: RDD[InternalRow], internalRowToUnsafeRowWithIndex: (Int, Iterator[InternalRow]) => Iterator[UnsafeRow] - ) { - def invoke(): RDD[InternalRow] = - TiSparkInfo.SPARK_VERSION match { - case "2.3.0" | "2.3.1" => - mapPartitionsWithIndexInternal - .invoke( - rdd, - internalRowToUnsafeRowWithIndex, - Boolean.box(false), - ClassTag.apply(classOf[UnsafeRow]) - ) - .asInstanceOf[RDD[InternalRow]] - case _ => - mapPartitionsWithIndexInternal - .invoke( - rdd, - internalRowToUnsafeRowWithIndex, - Boolean.box(false), - Boolean.box(false), - ClassTag.apply(classOf[UnsafeRow]) - ) - .asInstanceOf[RDD[InternalRow]] - } + ): RDD[InternalRow] = { + version match { + case "v1" => + // Spark-2.3.0 & Spark-2.3.1 + method + .invoke( + rdd, + internalRowToUnsafeRowWithIndex, + Boolean.box(false), + ClassTag.apply(classOf[UnsafeRow]) + ) + .asInstanceOf[RDD[InternalRow]] + + case _ => + // >= Spark-2.3.2 + method + .invoke( + rdd, + internalRowToUnsafeRowWithIndex, + Boolean.box(false), + Boolean.box(false), + ClassTag.apply(classOf[UnsafeRow]) + ) + .asInstanceOf[RDD[InternalRow]] + } } lazy val classLoader: URLClassLoader = {