diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownFilterQueries.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownFilterQueries.java index 7eebcd4bf5e6..fe7f264b81a9 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownFilterQueries.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownFilterQueries.java @@ -823,6 +823,17 @@ public void testSchemaEvolution() { assertUpdate("CREATE TABLE test_schema_evolution WITH (partitioned_by = ARRAY['regionkey']) AS SELECT nationkey, regionkey FROM nation", 25); assertUpdate("ALTER TABLE test_schema_evolution ADD COLUMN nation_plus_region BIGINT"); + + // constant filter function errors + assertQueryFails("SELECT * FROM test_schema_evolution WHERE coalesce(nation_plus_region, fail('constant filter error')) is not null", "constant filter error"); + assertQuerySucceeds("SELECT * FROM test_schema_evolution WHERE nationkey < 0 AND coalesce(nation_plus_region, fail('constant filter error')) is not null"); + assertQueryFails("SELECT * FROM test_schema_evolution WHERE nationkey % 2 = 0 AND coalesce(nation_plus_region, fail('constant filter error')) is not null", "constant filter error"); + + // non-deterministic filter function with constant inputs + assertQueryReturnsEmptyResult("SELECT * FROM test_schema_evolution WHERE nation_plus_region * rand() < 0"); + assertQuery("SELECT nationkey FROM test_schema_evolution WHERE nation_plus_region * rand() IS NULL", "SELECT nationkey FROM nation"); + assertQuerySucceeds("SELECT nationkey FROM test_schema_evolution WHERE coalesce(nation_plus_region, 1) * rand() < 0.5"); + assertUpdate("INSERT INTO test_schema_evolution SELECT nationkey, nationkey + regionkey, regionkey FROM nation", 25); assertUpdate("ALTER TABLE test_schema_evolution ADD COLUMN nation_minus_region BIGINT"); assertUpdate("INSERT INTO test_schema_evolution SELECT nationkey, nationkey + regionkey, nationkey - regionkey, 
regionkey FROM nation", 25); diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/OrcSelectiveRecordReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcSelectiveRecordReader.java index 8b796d7638ff..b6204598441f 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/OrcSelectiveRecordReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcSelectiveRecordReader.java @@ -39,6 +39,7 @@ import com.google.common.collect.Iterables; import com.google.common.collect.Maps; import com.google.common.io.Closer; +import com.google.common.primitives.Ints; import io.airlift.slice.Slice; import io.airlift.units.DataSize; import org.joda.time.DateTimeZone; @@ -47,6 +48,7 @@ import java.io.IOException; import java.io.UncheckedIOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; @@ -68,10 +70,12 @@ import static com.facebook.presto.spi.type.TinyintType.TINYINT; import static com.facebook.presto.spi.type.Varchars.isVarcharType; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Predicates.not; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.lang.Math.max; import static java.util.Objects.requireNonNull; public class OrcSelectiveRecordReader @@ -86,24 +90,46 @@ public class OrcSelectiveRecordReader private final Map columnTypes; // key: index into hiveColumnIndices array private final Object[] constantValues; // aligned with hiveColumnIndices array private final Function[] coercers; // aligned with hiveColumnIndices array - private final List filterFunctionsWithInputs; + + // non-deterministic filter function with no inputs (rand() < 0.1); evaluated before any column is read private final Optional 
filterFunctionWithoutInput; private final Map filterFunctionInputMapping; // channel-to-index-into-hiveColumnIndices-array mapping - private final Set filterFunctionInputs; // channels private final Map columnsWithFilterScores; // keys are indices into hiveColumnIndices array; values are filter scores // Optimal order of stream readers private int[] streamReaderOrder; // elements are indices into hiveColumnIndices array + private List[] filterFunctionsOrder; // aligned with streamReaderOrder order; each filter function is placed + // into a list positioned at the last necessary input + private Set[] filterFunctionInputs; // aligned with filterFunctionsOrder + + // non-deterministic filter functions with only constant inputs; evaluated before any column is read + private List filterFunctionsWithConstantInputs; + private Set filterFunctionConstantInputs; + // An immutable list of initial positions; includes all positions: 0,1,2,3,4,.. // This array may grow, but cannot shrink. The values don't change. 
private int[] positions; // Used in applyFilterFunctions; mutable private int[] outputPositions; + + // errors encountered while evaluating filter functions; indices are positions in the batch + // of rows being processed by getNextPage (errors[outputPositions[i]] is valid) private RuntimeException[] errors; + + // temporary array to be used in applyFilterFunctions only; exists solely for the purpose of re-using memory + // indices are positions in a page provided to the filter functions (it contains a subset of rows that passed earlier filters) + private RuntimeException[] tmpErrors; + + // flag indicating whether range filter on a constant column is false; no data is read in that case private boolean constantFilterIsFalse; + // an error occurred while evaluating a deterministic filter function with only constant + // inputs; thrown unless other filters eliminate all rows + @Nullable + private RuntimeException constantFilterError; + private int readPositions; public OrcSelectiveRecordReader( @@ -185,15 +211,12 @@ public OrcSelectiveRecordReader( this.outputColumns = outputColumns.stream().map(zeroBasedIndices::get).collect(toImmutableList()); this.columnTypes = includedColumns.entrySet().stream().collect(toImmutableMap(entry -> zeroBasedIndices.get(entry.getKey()), Map.Entry::getValue)); this.filterFunctionWithoutInput = getFilterFunctionWithoutInputs(filterFunctions); - this.filterFunctionsWithInputs = filterFunctions.stream() - .filter(function -> function.getInputChannels().length > 0) - .collect(toImmutableList()); - this.filterFunctionInputMapping = Maps.transformValues(filterFunctionInputMapping, zeroBasedIndices::get); - this.filterFunctionInputs = filterFunctions.stream() + + Set usedInputChannels = filterFunctions.stream() .flatMapToInt(function -> Arrays.stream(function.getInputChannels())) .boxed() - .map(this.filterFunctionInputMapping::get) .collect(toImmutableSet()); + this.filterFunctionInputMapping = 
Maps.transformValues(Maps.filterKeys(filterFunctionInputMapping, usedInputChannels::contains), zeroBasedIndices::get); this.columnsWithFilterScores = filters .entrySet() .stream() @@ -229,19 +252,124 @@ public OrcSelectiveRecordReader( } } + if (!evaluateDeterministicFilterFunctionsWithConstantInputs(filterFunctions)) { + constantFilterIsFalse = true; + // No further initialization needed. + return; + } + // Initial order of stream readers is: // - readers with integer equality // - readers with integer range / multivalues / inequality // - readers with filters // - followed by readers for columns that provide input to filter functions // - followed by readers for columns that doesn't have any filtering - streamReaderOrder = orderStreamReaders(columnTypes.keySet().stream().filter(index -> this.constantValues[index] == null).collect(toImmutableSet()), columnsWithFilterScores, filterFunctionInputs, columnTypes); + streamReaderOrder = orderStreamReaders(columnTypes.keySet().stream().filter(index -> this.constantValues[index] == null).collect(toImmutableSet()), columnsWithFilterScores, this.filterFunctionInputMapping.keySet(), columnTypes); + + List filterFunctionsWithInputs = filterFunctions.stream() + .filter(OrcSelectiveRecordReader::hasInputs) + .filter(not(this::allConstantInputs)) + .collect(toImmutableList()); + + // figure out when to evaluate filter functions; a function is ready for evaluation as soon as the last input has been read + filterFunctionsOrder = orderFilterFunctionsWithInputs(streamReaderOrder, filterFunctionsWithInputs, this.filterFunctionInputMapping); + filterFunctionInputs = collectFilterFunctionInputs(filterFunctionsOrder, this.filterFunctionInputMapping); + + filterFunctionsWithConstantInputs = filterFunctions.stream() + .filter(not(FilterFunction::isDeterministic)) + .filter(OrcSelectiveRecordReader::hasInputs) + .filter(this::allConstantInputs) + .collect(toImmutableList()); + filterFunctionConstantInputs = 
filterFunctionsWithConstantInputs.stream() + .flatMapToInt(function -> Arrays.stream(function.getInputChannels())) + .boxed() + .map(this.filterFunctionInputMapping::get) + .collect(toImmutableSet()); + } + + private boolean evaluateDeterministicFilterFunctionsWithConstantInputs(List filterFunctions) + { + for (FilterFunction function : filterFunctions) { + if (function.isDeterministic() && hasInputs(function) && allConstantInputs(function) && !evaluateDeterministicFilterFunctionWithConstantInputs(function)) { + return false; + } + } + return true; + } + + private boolean evaluateDeterministicFilterFunctionWithConstantInputs(FilterFunction function) + { + int[] inputs = function.getInputChannels(); + Block[] blocks = new Block[inputs.length]; + for (int i = 0; i < inputs.length; i++) { + int columnIndex = filterFunctionInputMapping.get(inputs[i]); + Object constantValue = constantValues[columnIndex]; + blocks[i] = RunLengthEncodedBlock.create(columnTypes.get(columnIndex), constantValue == NULL_MARKER ? 
null : constantValue, 1); + } + + initializeTmpErrors(1); + int positionCount = function.filter(new Page(blocks), new int[] {0}, 1, tmpErrors); + + if (tmpErrors[0] != null) { + constantFilterError = tmpErrors[0]; + } + return positionCount == 1; + } + + private static boolean hasInputs(FilterFunction function) + { + return function.getInputChannels().length > 0; + } + + private boolean allConstantInputs(FilterFunction function) + { + return Arrays.stream(function.getInputChannels()) + .map(filterFunctionInputMapping::get) + .allMatch(columnIndex -> constantValues[columnIndex] != null); + } + + private static List[] orderFilterFunctionsWithInputs(int[] streamReaderOrder, List filterFunctions, Map inputMapping) + { + List[] order = new List[streamReaderOrder.length]; + for (FilterFunction function : filterFunctions) { + int[] inputs = function.getInputChannels(); + int lastIndex = -1; + for (int input : inputs) { + int columnIndex = inputMapping.get(input); + lastIndex = max(lastIndex, Ints.indexOf(streamReaderOrder, columnIndex)); + } + + verify(lastIndex >= 0); + if (order[lastIndex] == null) { + order[lastIndex] = new ArrayList<>(); + } + order[lastIndex].add(function); + } + + return order; + } + + private static Set[] collectFilterFunctionInputs(List[] functionsOrder, Map inputMapping) + { + Set[] inputs = new Set[functionsOrder.length]; + for (int i = 0; i < functionsOrder.length; i++) { + List functions = functionsOrder[i]; + if (functions != null) { + inputs[i] = functions.stream() + .flatMapToInt(function -> Arrays.stream(function.getInputChannels())) + .boxed() + .map(inputMapping::get) + .collect(toImmutableSet()); + } + } + + return inputs; } private static Optional getFilterFunctionWithoutInputs(List filterFunctions) { List functions = filterFunctions.stream() - .filter(function -> function.getInputChannels().length == 0) + .filter(not(OrcSelectiveRecordReader::hasInputs)) .collect(toImmutableList()); if (functions.isEmpty()) { return Optional.empty(); 
@@ -429,20 +557,22 @@ public Page getNextPage() positionsToRead = outputPositions; } - boolean filterFunctionsApplied = filterFunctionsWithInputs.isEmpty(); - int offset = getNextRowInGroup(); + if (!filterFunctionsWithConstantInputs.isEmpty()) { + positionCount = applyFilterFunctions(filterFunctionsWithConstantInputs, filterFunctionConstantInputs, positionsToRead, positionCount); - for (int columnIndex : streamReaderOrder) { - if (!filterFunctionsApplied && !hasAnyFilter(columnIndex)) { - positionCount = applyFilterFunctions(positionsToRead, positionCount); - if (positionCount == 0) { - break; - } - - positionsToRead = outputPositions; - filterFunctionsApplied = true; + if (positionCount == 0) { + batchRead(batchSize); + return EMPTY_PAGE; } + positionsToRead = outputPositions; + } + + int offset = getNextRowInGroup(); + + for (int i = 0; i < streamReaderOrder.length; i++) { + int columnIndex = streamReaderOrder[i]; + if (!hasAnyFilter(columnIndex)) { break; } @@ -455,11 +585,15 @@ public Page getNextPage() positionsToRead = streamReader.getReadPositions(); verify(positionCount == 1 || positionsToRead[positionCount - 1] - positionsToRead[0] >= positionCount - 1, "positions must monotonically increase"); - } - if (positionCount > 0 && !filterFunctionsApplied) { - positionCount = applyFilterFunctions(positionsToRead, positionCount); - positionsToRead = outputPositions; + if (filterFunctionsOrder[i] != null) { + positionCount = applyFilterFunctions(filterFunctionsOrder[i], filterFunctionInputs[i], positionsToRead, positionCount); + if (positionCount == 0) { + break; + } + + positionsToRead = outputPositions; + } } batchRead(batchSize); @@ -468,11 +602,13 @@ public Page getNextPage() return EMPTY_PAGE; } - if (filterFunctionsWithInputs.isEmpty() && filterFunctionWithoutInput.isPresent()) { - for (int i = 0; i < positionCount; i++) { - if (errors[positionsToRead[i]] != null) { - throw errors[positionsToRead[i]]; - } + if (constantFilterError != null) { + throw 
constantFilterError; + } + + for (int i = 0; i < positionCount; i++) { + if (errors[positionsToRead[i]] != null) { + throw errors[positionsToRead[i]]; } } @@ -517,7 +653,7 @@ private SelectiveStreamReader getStreamReader(int columnIndex) private boolean hasAnyFilter(int columnIndex) { - return columnsWithFilterScores.containsKey(columnIndex) || filterFunctionInputs.contains(columnIndex); + return columnsWithFilterScores.containsKey(columnIndex) || filterFunctionInputMapping.containsKey(columnIndex); } private void initializePositions(int batchSize) @@ -544,7 +680,7 @@ private int applyFilterFunctionWithNoInputs(int positionCount) return filterFunctionWithoutInput.get().filter(page, outputPositions, positionCount, errors); } - private int applyFilterFunctions(int[] positions, int positionCount) + private int applyFilterFunctions(List filterFunctions, Set filterFunctionInputs, int[] positions, int positionCount) { BlockLease[] blockLeases = new BlockLease[hiveColumnIndices.length]; Block[] blocks = new Block[hiveColumnIndices.length]; @@ -562,16 +698,17 @@ private int applyFilterFunctions(int[] positions, int positionCount) } } - if (filterFunctionWithoutInput.isPresent()) { - for (int i = 0; i < positionCount; i++) { - errors[i] = errors[positions[i]]; - } + initializeTmpErrors(positionCount); + for (int i = 0; i < positionCount; i++) { + tmpErrors[i] = errors[positions[i]]; } + Arrays.fill(errors, null); + try { initializeOutputPositions(positionCount); - for (FilterFunction function : filterFunctionsWithInputs) { + for (FilterFunction function : filterFunctions) { int[] inputs = function.getInputChannels(); Block[] inputBlocks = new Block[inputs.length]; @@ -580,23 +717,18 @@ private int applyFilterFunctions(int[] positions, int positionCount) } Page page = new Page(positionCount, inputBlocks); - positionCount = function.filter(page, outputPositions, positionCount, errors); + positionCount = function.filter(page, outputPositions, positionCount, tmpErrors); if 
(positionCount == 0) { break; } } - for (int i = 0; i < positionCount; i++) { - if (errors[i] != null) { - throw errors[i]; - } - } - // at this point outputPositions are relative to page, e.g. they are indices into positions array // translate outputPositions to positions relative to the start of the row group, // e.g. make outputPositions a subset of positions array for (int i = 0; i < positionCount; i++) { outputPositions[i] = positions[outputPositions[i]]; + errors[outputPositions[i]] = tmpErrors[i]; } return positionCount; } @@ -609,6 +741,16 @@ private int applyFilterFunctions(int[] positions, int positionCount) } } + private void initializeTmpErrors(int positionCount) + { + if (tmpErrors == null || tmpErrors.length < positionCount) { + tmpErrors = new RuntimeException[positionCount]; + } + else { + Arrays.fill(tmpErrors, null); + } + } + private void initializeOutputPositions(int positionCount) { if (outputPositions == null || outputPositions.length < positionCount) {