diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala
index c854bd9a012c..6493d2cd3a78 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala
@@ -56,10 +56,10 @@ import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, Par
 import org.apache.spark.sql.execution.QueryExecutionException
 import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.execution.datasources.orc.OrcUtils
+import org.apache.spark.sql.execution.datasources.rapids.OrcFiltersWrapper
 import org.apache.spark.sql.execution.datasources.v2.{EmptyPartitionReader, FilePartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.rapids.OrcFilters
 import org.apache.spark.sql.rapids.execution.TrampolineUtil
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, MapType, StructType}
@@ -824,7 +824,7 @@ private case class GpuOrcFileFilterHandler(
       val readerOpts = OrcInputFormat.buildOptions(
         conf, orcReader, partFile.start, partFile.length)
       // create the search argument if we have pushed filters
-      OrcFilters.createFilter(fullSchema, pushedFilters).foreach { f =>
+      OrcFiltersWrapper.createFilter(fullSchema, pushedFilters).foreach { f =>
         readerOpts.searchArgument(f, fullSchema.fieldNames)
       }
       readerOpts
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/rapids/OrcFiltersWrapper.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/rapids/OrcFiltersWrapper.scala
new file mode 100644
index 000000000000..65792c76c824
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/datasources/rapids/OrcFiltersWrapper.scala
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.rapids
+
+import org.apache.hadoop.hive.ql.io.sarg.SearchArgument
+
+import org.apache.spark.sql.execution.datasources.orc.OrcFilters
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+
+// Wrapper for Spark's OrcFilters, which is in a private package
+object OrcFiltersWrapper {
+  def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
+    OrcFilters.createFilter(schema, filters)
+  }
+}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala
deleted file mode 100644
index 2dd9973cafd0..000000000000
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala
+++ /dev/null
@@ -1,278 +0,0 @@
-/*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.rapids
-
-import java.time.{Instant, LocalDate}
-
-import org.apache.hadoop.hive.common.`type`.HiveDecimal
-import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument}
-import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder
-import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder
-import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable
-
-import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types._
-
-// This is derived from Apache Spark's OrcFilters code to avoid calling the
-// Spark version. Spark's version can potentially create a search argument
-// applier object that is incompatible with the orc:nohive jar that has been
-// shaded as part of this project.
-
-/**
- * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down.
- *
- * Due to limitation of ORC `SearchArgument` builder, we had to implement separate checking and
- * conversion passes through the Filter to make sure we only convert predicates that are known
- * to be convertible.
- *
- * An ORC `SearchArgument` must be built in one pass using a single builder. For example, you can't
- * build `a = 1` and `b = 2` first, and then combine them into `a = 1 AND b = 2`. This is quite
- * different from the cases in Spark SQL or Parquet, where complex filters can be easily built using
- * existing simpler ones.
- *
- * The annoying part is that, `SearchArgument` builder methods like `startAnd()`, `startOr()`, and
- * `startNot()` mutate internal state of the builder instance. This forces us to translate all
- * convertible filters with a single builder instance. However, if we try to translate a filter
- * before checking whether it can be converted or not, we may end up with a builder whose internal
- * state is inconsistent in the case of an inconvertible filter.
- *
- * For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and then
- * try to convert its children. Say we convert `left` child successfully, but find that `right`
- * child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is inconsistent
- * now.
- *
- * The workaround employed here is to trim the Spark filters before trying to convert them. This
- * way, we can only do the actual conversion on the part of the Filter that is known to be
- * convertible.
- *
- * P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. Usage of
- * builder methods mentioned above can only be found in test code, where all tested filters are
- * known to be convertible.
- */
-object OrcFilters extends OrcFiltersBase {
-
-  /**
-   * Create ORC filter as a SearchArgument instance.
-   */
-  def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
-    val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
-    // Combines all convertible filters using `And` to produce a single conjunction
-    val conjunctionOptional = buildTree(convertibleFilters(dataTypeMap, filters))
-    conjunctionOptional.map { conjunction =>
-      // Then tries to build a single ORC `SearchArgument` for the conjunction predicate.
-      // The input predicate is fully convertible. There should not be any empty result in the
-      // following recursive method call `buildSearchArgument`.
-      buildSearchArgument(dataTypeMap, conjunction, newBuilder).build()
-    }
-  }
-
-  def convertibleFilters(
-      dataTypeMap: Map[String, OrcPrimitiveField],
-      filters: Seq[Filter]): Seq[Filter] = {
-    import org.apache.spark.sql.sources._
-
-    def convertibleFiltersHelper(
-        filter: Filter,
-        canPartialPushDown: Boolean): Option[Filter] = filter match {
-      // At here, it is not safe to just convert one side and remove the other side
-      // if we do not understand what the parent filters are.
-      //
-      // Here is an example used to explain the reason.
-      // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to
-      // convert b in ('1'). If we only convert a = 2, we will end up with a filter
-      // NOT(a = 2), which will generate wrong results.
-      //
-      // Pushing one side of AND down is only safe to do at the top level or in the child
-      // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate
-      // can be safely removed.
-      case And(left, right) =>
-        val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown)
-        val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown)
-        (leftResultOptional, rightResultOptional) match {
-          case (Some(leftResult), Some(rightResult)) => Some(And(leftResult, rightResult))
-          case (Some(leftResult), None) if canPartialPushDown => Some(leftResult)
-          case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult)
-          case _ => None
-        }
-
-      // The Or predicate is convertible when both of its children can be pushed down.
-      // That is to say, if one/both of the children can be partially pushed down, the Or
-      // predicate can be partially pushed down as well.
-      //
-      // Here is an example used to explain the reason.
-      // Let's say we have
-      // (a1 AND a2) OR (b1 AND b2),
-      // a1 and b1 is convertible, while a2 and b2 is not.
-      // The predicate can be converted as
-      // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2)
-      // As per the logical in And predicate, we can push down (a1 OR b1).
-      case Or(left, right) =>
-        for {
-          lhs <- convertibleFiltersHelper(left, canPartialPushDown)
-          rhs <- convertibleFiltersHelper(right, canPartialPushDown)
-        } yield Or(lhs, rhs)
-      case Not(pred) =>
-        val childResultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false)
-        childResultOptional.map(Not)
-      case other =>
-        for (_ <- buildLeafSearchArgument(dataTypeMap, other, newBuilder())) yield other
-    }
-    filters.flatMap { filter =>
-      convertibleFiltersHelper(filter, true)
-    }
-  }
-
-  /**
-   * Get PredicateLeafType which is corresponding to the given DataType.
-   */
-  def getPredicateLeafType(dataType: DataType): PredicateLeaf.Type = dataType match {
-    case BooleanType => PredicateLeaf.Type.BOOLEAN
-    case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
-    case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
-    case StringType => PredicateLeaf.Type.STRING
-    case DateType => PredicateLeaf.Type.DATE
-    case TimestampType => PredicateLeaf.Type.TIMESTAMP
-    case _: DecimalType => PredicateLeaf.Type.DECIMAL
-    case _ => throw new UnsupportedOperationException(s"DataType: ${dataType.catalogString}")
-  }
-
-  /**
-   * Cast literal values for filters.
-   *
-   * We need to cast to long because ORC raises exceptions
-   * at 'checkLiteralType' of SearchArgumentImpl.java.
-   */
-  private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match {
-    case ByteType | ShortType | IntegerType | LongType =>
-      value.asInstanceOf[Number].longValue
-    case FloatType | DoubleType =>
-      value.asInstanceOf[Number].doubleValue()
-    case _: DecimalType =>
-      new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[java.math.BigDecimal]))
-    case _: DateType if value.isInstanceOf[LocalDate] =>
-      toJavaDate(localDateToDays(value.asInstanceOf[LocalDate]))
-    case _: TimestampType if value.isInstanceOf[Instant] =>
-      toJavaTimestamp(instantToMicros(value.asInstanceOf[Instant]))
-    case _ => value
-  }
-
-  /**
-   * Build a SearchArgument and return the builder so far.
-   *
-   * @param dataTypeMap a map from the attribute name to its data type.
-   * @param expression the input predicates, which should be fully convertible to SearchArgument.
-   * @param builder the input SearchArgument.Builder.
-   * @return the builder so far.
-   */
-  private def buildSearchArgument(
-      dataTypeMap: Map[String, OrcPrimitiveField],
-      expression: Filter,
-      builder: Builder): Builder = {
-    import org.apache.spark.sql.sources._
-
-    expression match {
-      case And(left, right) =>
-        val lhs = buildSearchArgument(dataTypeMap, left, builder.startAnd())
-        val rhs = buildSearchArgument(dataTypeMap, right, lhs)
-        rhs.end()
-
-      case Or(left, right) =>
-        val lhs = buildSearchArgument(dataTypeMap, left, builder.startOr())
-        val rhs = buildSearchArgument(dataTypeMap, right, lhs)
-        rhs.end()
-
-      case Not(child) =>
-        buildSearchArgument(dataTypeMap, child, builder.startNot()).end()
-
-      case other =>
-        buildLeafSearchArgument(dataTypeMap, other, builder).getOrElse {
-          throw new SparkException(
-            "The input filter of OrcFilters.buildSearchArgument should be fully convertible.")
-        }
-    }
-  }
-
-  /**
-   * Build a SearchArgument for a leaf predicate and return the builder so far.
-   *
-   * @param dataTypeMap a map from the attribute name to its data type.
-   * @param expression the input filter predicates.
-   * @param builder the input SearchArgument.Builder.
-   * @return the builder so far.
-   */
-  private def buildLeafSearchArgument(
-      dataTypeMap: Map[String, OrcPrimitiveField],
-      expression: Filter,
-      builder: Builder): Option[Builder] = {
-    def getType(attribute: String): PredicateLeaf.Type =
-      getPredicateLeafType(dataTypeMap(attribute).fieldType)
-
-    import org.apache.spark.sql.sources._
-
-    // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()`
-    // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be
-    // wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
-    expression match {
-      case EqualTo(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
-        Some(builder.startAnd()
-          .equals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
-
-      case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
-        Some(builder.startAnd()
-          .nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
-
-      case LessThan(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
-        Some(builder.startAnd()
-          .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())
-
-      case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
-        Some(builder.startAnd()
-          .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
-
-      case GreaterThan(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
-        Some(builder.startNot()
-          .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
-
-      case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
-        Some(builder.startNot()
-          .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())
-
-      case IsNull(name) if dataTypeMap.contains(name) =>
-        Some(builder.startAnd()
-          .isNull(dataTypeMap(name).fieldName, getType(name)).end())
-
-      case IsNotNull(name) if dataTypeMap.contains(name) =>
-        Some(builder.startNot()
-          .isNull(dataTypeMap(name).fieldName, getType(name)).end())
-
-      case In(name, values) if dataTypeMap.contains(name) =>
-        val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType))
-        Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name),
-          castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
-
-      case _ => None
-    }
-  }
-}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFiltersBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFiltersBase.scala
deleted file mode 100644
index d4fb2f260d60..000000000000
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFiltersBase.scala
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Copyright (c) 2021, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.rapids
-
-import java.util.Locale
-
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.sources.{And, Filter}
-import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType, StructField, StructType}
-
-/**
- * Methods that can be shared when upgrading the built-in Hive.
- *
- * Derived from Apache Spark to avoid depending upon it directly,
- * since its API has changed between Spark versions.
- */
-trait OrcFiltersBase {
-
-  private[sql] def buildTree(filters: Seq[Filter]): Option[Filter] = {
-    filters match {
-      case Seq() => None
-      case Seq(filter) => Some(filter)
-      case Seq(filter1, filter2) => Some(And(filter1, filter2))
-      case _ => // length > 2
-        val (left, right) = filters.splitAt(filters.length / 2)
-        Some(And(buildTree(left).get, buildTree(right).get))
-    }
-  }
-
-  case class OrcPrimitiveField(fieldName: String, fieldType: DataType)
-
-  /**
-   * This method returns a map which contains ORC field name and data type. Each key
-   * represents a column; `dots` are used as separators for nested columns. If any part
-   * of the names contains `dots`, it is quoted to avoid confusion. See
-   * `org.apache.spark.sql.connector.catalog.quoted` for implementation details.
-   *
-   * BinaryType, UserDefinedType, ArrayType and MapType are ignored.
-   */
-  protected[sql] def getSearchableTypeMap(
-      schema: StructType,
-      caseSensitive: Boolean): Map[String, OrcPrimitiveField] = {
-    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
-
-    def getPrimitiveFields(
-        fields: Seq[StructField],
-        parentFieldNames: Seq[String] = Seq.empty): Seq[(String, OrcPrimitiveField)] = {
-      fields.flatMap { f =>
-        f.dataType match {
-          case st: StructType =>
-            getPrimitiveFields(st.fields, parentFieldNames :+ f.name)
-          case BinaryType => None
-          case _: AtomicType =>
-            val fieldName = (parentFieldNames :+ f.name).quoted
-            val orcField = OrcPrimitiveField(fieldName, f.dataType)
-            Some((fieldName, orcField))
-          case _ => None
-        }
-      }
-    }
-
-    val primitiveFields = getPrimitiveFields(schema.fields)
-    if (caseSensitive) {
-      primitiveFields.toMap
-    } else {
-      // Don't consider ambiguity here, i.e. more than one field are matched in case insensitive
-      // mode, just skip pushdown for these fields, they will trigger Exception when reading,
-      // See: SPARK-25175.
-      val dedupPrimitiveFields = primitiveFields
-        .groupBy(_._1.toLowerCase(Locale.ROOT))
-        .filter(_._2.size == 1)
-        .mapValues(_.head._2)
-      CaseInsensitiveMap(dedupPrimitiveFields)
-    }
-  }
-}
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/OrcScanSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/OrcScanSuite.scala
index 997409412fb5..a94affbf08df 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/OrcScanSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/OrcScanSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -106,6 +106,9 @@ class OrcScanSuite extends SparkQueryCompareTestSuite {
    * is actually 1582-09-23 in proleptic Gregorian calendar.
    */
   test("test hybrid Julian Gregorian calendar vs proleptic Gregorian calendar") {
+    // Since Spark 3.1.1, ORC fails to prune when converting Hybrid calendar to Proleptic calendar
+    // ORC bug: https://issues.apache.org/jira/browse/ORC-1083
+    assumePriorToSpark311
     withCpuSparkSession(spark => {
       val df = frameFromOrcWithSchema("hybrid-Julian-calendar.orc",
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
index 1f84b04ad773..b8357c9db159 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -1835,6 +1835,9 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
   def assumeSpark320orLater: Assertion =
     assume(VersionUtils.isSpark320OrLater, "Spark version not 3.2.0+")
 
+  def assumePriorToSpark311: Assertion =
+    assume(!VersionUtils.isSpark311OrLater, "Spark version not before 3.1.1")
+
   def cmpSparkVersion(major: Int, minor: Int, bugfix: Int): Int = {
     val sparkShimVersion = ShimLoader.getSparkShims.getSparkShimVersion
     val (sparkMajor, sparkMinor, sparkBugfix) = sparkShimVersion match {
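
For context, a minimal sketch of how plugin code reaches Spark's package-private OrcFilters through the new wrapper. The schema and filter values below are hypothetical, purely for illustration; only OrcFiltersWrapper.createFilter comes from this patch.

import org.apache.spark.sql.execution.datasources.rapids.OrcFiltersWrapper
import org.apache.spark.sql.sources.{EqualTo, GreaterThan}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object OrcFiltersWrapperExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical read schema and pushed-down data source filters.
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))
    val pushed = Seq(GreaterThan("id", 10), EqualTo("name", "a"))

    // The same call GpuOrcFileFilterHandler now makes; returns None when no
    // filter is convertible to an ORC SearchArgument.
    OrcFiltersWrapper.createFilter(schema, pushed).foreach { sarg =>
      println(sarg) // prints the SearchArgument's textual leaf/expression form
    }
  }
}

Because the wrapper lives in org.apache.spark.sql.execution.datasources.rapids, it can legally call the package-private Spark class, while plugin code in com.nvidia.spark.rapids cannot; that is the entire reason the one-method object exists.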