Skip to content

Commit

Permalink
fix reflection bug: pass in different arguments for different version…
Browse files Browse the repository at this point in the history
… of same function (pingcap#1037) (pingcap#1052)

(cherry picked from commit a5462c2)
  • Loading branch information
marsishandsome authored Aug 21, 2019
1 parent 8ccc72b commit 8b6f1a3
Showing 1 changed file with 96 additions and 68 deletions.
164 changes: 96 additions & 68 deletions core/src/main/scala/com/pingcap/tispark/utils/ReflectionUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,86 +59,114 @@ object ReflectionUtil {
// isOrderSensitive: Boolean = false): RDD[U]
//
// Hereby we use reflection to support different Spark versions.
private val mapPartitionsWithIndexInternal: Method = TiSparkInfo.SPARK_VERSION match {
case "2.3.0" | "2.3.1" =>
tryLoadMethod(
"mapPartitionsWithIndexInternal",
mapPartitionsWithIndexInternalV1,
mapPartitionsWithIndexInternalV2
)
case _ =>
// Spark version >= 2.3.2
tryLoadMethod(
"mapPartitionsWithIndexInternal",
mapPartitionsWithIndexInternalV2,
mapPartitionsWithIndexInternalV1
)
}
case class ReflectionMapPartitionWithIndexInternal(
rdd: RDD[InternalRow],
internalRowToUnsafeRowWithIndex: (Int, Iterator[InternalRow]) => Iterator[UnsafeRow]
) {
  // The Spark HDP Release may not be compatible with the official Release
// see https://github.com/pingcap/tispark/issues/1006
def invoke(): RDD[InternalRow] = {
val (version, method) = TiSparkInfo.SPARK_VERSION match {
case "2.3.0" | "2.3.1" =>
try {
reflectMapPartitionsWithIndexInternalV1(rdd, internalRowToUnsafeRowWithIndex)
} catch {
case _: Throwable =>
try {
reflectMapPartitionsWithIndexInternalV2(rdd, internalRowToUnsafeRowWithIndex)
} catch {
case _: Throwable =>
throw ScalaReflectionException(
s"Cannot find reflection of Method mapPartitionsWithIndexInternal, current Spark version is %s"
.format(TiSparkInfo.SPARK_VERSION)
)
}
}

case _ =>
try {
reflectMapPartitionsWithIndexInternalV2(rdd, internalRowToUnsafeRowWithIndex)
} catch {
case _: Throwable =>
try {
reflectMapPartitionsWithIndexInternalV1(rdd, internalRowToUnsafeRowWithIndex)
} catch {
case _: Throwable =>
throw ScalaReflectionException(
s"Cannot find reflection of Method mapPartitionsWithIndexInternal, current Spark version is %s"
.format(TiSparkInfo.SPARK_VERSION)
)
}
}
}

  // The Spark HDP Release may not be compatible with the official Release
// see https://github.com/pingcap/tispark/issues/1006
private def tryLoadMethod(name: String, f1: () => Method, f2: () => Method): Method = {
try {
f1.apply()
} catch {
case _: Throwable =>
try {
f2.apply()
} catch {
case _: Throwable =>
throw ScalaReflectionException(
s"Cannot find reflection of Method $name, current Spark version is %s"
.format(TiSparkInfo.SPARK_VERSION)
)
}
invokeMapPartitionsWithIndexInternal(version, method, rdd, internalRowToUnsafeRowWithIndex)
}
}

// Spark-2.3.0 & Spark-2.3.1
private def mapPartitionsWithIndexInternalV1(): Method =
classOf[RDD[InternalRow]].getDeclaredMethod(
"mapPartitionsWithIndexInternal",
classOf[(Int, Iterator[InternalRow]) => Iterator[UnsafeRow]],
classOf[Boolean],
classOf[ClassTag[UnsafeRow]]
// Look up the three-argument overload of RDD.mapPartitionsWithIndexInternal
// (f, preservesPartitioning, ClassTag) — used for Spark 2.3.0 / 2.3.1 per the
// caller's version dispatch. Returns the version tag "v1" paired with the
// reflected Method so the invoker knows which argument list to pass.
// Throws NoSuchMethodException when the running Spark does not declare this
// overload; the caller catches it and falls back to the v2 signature.
private def reflectMapPartitionsWithIndexInternalV1(
    rdd: RDD[InternalRow],
    internalRowToUnsafeRowWithIndex: (Int, Iterator[InternalRow]) => Iterator[UnsafeRow]
): (String, Method) = {
  val method = classOf[RDD[InternalRow]].getDeclaredMethod(
    "mapPartitionsWithIndexInternal",
    classOf[(Int, Iterator[InternalRow]) => Iterator[UnsafeRow]],
    classOf[Boolean],
    classOf[ClassTag[UnsafeRow]]
  )
  ("v1", method)
}

// >= Spark-2.3.2
private def mapPartitionsWithIndexInternalV2(): Method =
classOf[RDD[InternalRow]].getDeclaredMethod(
"mapPartitionsWithIndexInternal",
classOf[(Int, Iterator[InternalRow]) => Iterator[UnsafeRow]],
classOf[Boolean],
classOf[Boolean],
classOf[ClassTag[UnsafeRow]]
// Look up the four-argument overload of RDD.mapPartitionsWithIndexInternal
// (f, preservesPartitioning, isOrderSensitive, ClassTag) — used for
// Spark >= 2.3.2 per the caller's version dispatch. Returns the version tag
// "v2" paired with the reflected Method so the invoker knows which argument
// list to pass. Throws NoSuchMethodException when this overload is absent;
// the caller catches it and falls back to the v1 signature.
private def reflectMapPartitionsWithIndexInternalV2(
    rdd: RDD[InternalRow],
    internalRowToUnsafeRowWithIndex: (Int, Iterator[InternalRow]) => Iterator[UnsafeRow]
): (String, Method) = {
  val method = classOf[RDD[InternalRow]].getDeclaredMethod(
    "mapPartitionsWithIndexInternal",
    classOf[(Int, Iterator[InternalRow]) => Iterator[UnsafeRow]],
    classOf[Boolean],
    classOf[Boolean],
    classOf[ClassTag[UnsafeRow]]
  )
  ("v2", method)
}

case class ReflectionMapPartitionWithIndexInternal(
private def invokeMapPartitionsWithIndexInternal(
version: String,
method: Method,
rdd: RDD[InternalRow],
internalRowToUnsafeRowWithIndex: (Int, Iterator[InternalRow]) => Iterator[UnsafeRow]
) {
def invoke(): RDD[InternalRow] =
TiSparkInfo.SPARK_VERSION match {
case "2.3.0" | "2.3.1" =>
mapPartitionsWithIndexInternal
.invoke(
rdd,
internalRowToUnsafeRowWithIndex,
Boolean.box(false),
ClassTag.apply(classOf[UnsafeRow])
)
.asInstanceOf[RDD[InternalRow]]
case _ =>
mapPartitionsWithIndexInternal
.invoke(
rdd,
internalRowToUnsafeRowWithIndex,
Boolean.box(false),
Boolean.box(false),
ClassTag.apply(classOf[UnsafeRow])
)
.asInstanceOf[RDD[InternalRow]]
}
): RDD[InternalRow] = {
version match {
case "v1" =>
// Spark-2.3.0 & Spark-2.3.1
method
.invoke(
rdd,
internalRowToUnsafeRowWithIndex,
Boolean.box(false),
ClassTag.apply(classOf[UnsafeRow])
)
.asInstanceOf[RDD[InternalRow]]

case _ =>
// >= Spark-2.3.2
method
.invoke(
rdd,
internalRowToUnsafeRowWithIndex,
Boolean.box(false),
Boolean.box(false),
ClassTag.apply(classOf[UnsafeRow])
)
.asInstanceOf[RDD[InternalRow]]
}
}

lazy val classLoader: URLClassLoader = {
Expand Down

0 comments on commit 8b6f1a3

Please sign in to comment.