Skip to content

Commit

Permalink
Evaluate filter function as soon as all inputs have been read
Browse files Browse the repository at this point in the history
  • Loading branch information
mbasmanova committed Dec 10, 2019
1 parent 5589d3d commit ac93bd6
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,17 @@ public void testSchemaEvolution()
{
assertUpdate("CREATE TABLE test_schema_evolution WITH (partitioned_by = ARRAY['regionkey']) AS SELECT nationkey, regionkey FROM nation", 25);
assertUpdate("ALTER TABLE test_schema_evolution ADD COLUMN nation_plus_region BIGINT");

// constant filter function errors
assertQueryFails("SELECT * FROM test_schema_evolution WHERE coalesce(nation_plus_region, fail('constant filter error')) is not null", "constant filter error");
assertQuerySucceeds("SELECT * FROM test_schema_evolution WHERE nationkey < 0 AND coalesce(nation_plus_region, fail('constant filter error')) is not null");
assertQueryFails("SELECT * FROM test_schema_evolution WHERE nationkey % 2 = 0 AND coalesce(nation_plus_region, fail('constant filter error')) is not null", "constant filter error");

// non-deterministic filter function with constant inputs
assertQueryReturnsEmptyResult("SELECT * FROM test_schema_evolution WHERE nation_plus_region * rand() < 0");
assertQuery("SELECT nationkey FROM test_schema_evolution WHERE nation_plus_region * rand() IS NULL", "SELECT nationkey FROM nation");
assertQuerySucceeds("SELECT nationkey FROM test_schema_evolution WHERE coalesce(nation_plus_region, 1) * rand() < 0.5");

assertUpdate("INSERT INTO test_schema_evolution SELECT nationkey, nationkey + regionkey, regionkey FROM nation", 25);
assertUpdate("ALTER TABLE test_schema_evolution ADD COLUMN nation_minus_region BIGINT");
assertUpdate("INSERT INTO test_schema_evolution SELECT nationkey, nationkey + regionkey, nationkey - regionkey, regionkey FROM nation", 25);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import com.google.common.io.Closer;
import com.google.common.primitives.Ints;
import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
import org.joda.time.DateTimeZone;
Expand All @@ -47,6 +48,7 @@

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
Expand All @@ -68,10 +70,12 @@
import static com.facebook.presto.spi.type.TinyintType.TINYINT;
import static com.facebook.presto.spi.type.Varchars.isVarcharType;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Predicates.not;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.lang.Math.max;
import static java.util.Objects.requireNonNull;

public class OrcSelectiveRecordReader
Expand All @@ -86,24 +90,46 @@ public class OrcSelectiveRecordReader
private final Map<Integer, Type> columnTypes; // key: index into hiveColumnIndices array
private final Object[] constantValues; // aligned with hiveColumnIndices array
private final Function<Block, Block>[] coercers; // aligned with hiveColumnIndices array
private final List<FilterFunction> filterFunctionsWithInputs;

// non-deterministic filter function with no inputs (rand() < 0.1); evaluated before any column is read
private final Optional<FilterFunction> filterFunctionWithoutInput;
private final Map<Integer, Integer> filterFunctionInputMapping; // channel-to-index-into-hiveColumnIndices-array mapping
private final Set<Integer> filterFunctionInputs; // channels
private final Map<Integer, Integer> columnsWithFilterScores; // keys are indices into hiveColumnIndices array; values are filter scores

// Optimal order of stream readers
private int[] streamReaderOrder; // elements are indices into hiveColumnIndices array

private List<FilterFunction>[] filterFunctionsOrder; // aligned with streamReaderOrder order; each filter function is placed
// into a list positioned at the last necessary input
private Set<Integer>[] filterFunctionInputs; // aligned with filterFunctionsOrder

// non-deterministic filter functions with only constant inputs; evaluated before any column is read
private List<FilterFunction> filterFunctionsWithConstantInputs;
private Set<Integer> filterFunctionConstantInputs;

// An immutable list of initial positions; includes all positions: 0,1,2,3,4,..
// This array may grow, but cannot shrink. The values don't change.
private int[] positions;

// Used in applyFilterFunctions; mutable
private int[] outputPositions;

// errors encountered while evaluating filter functions; indices are positions in the batch
// of rows being processed by getNextPage (errors[outputPositions[i]] is valid)
private RuntimeException[] errors;

// temporary array to be used in applyFilterFunctions only; exists solely for the purpose of re-using memory
// indices are positions in a page provided to the filter functions (it contains a subset of rows that passed earlier filters)
private RuntimeException[] tmpErrors;

// flag indicating whether range filter on a constant column is false; no data is read in that case
private boolean constantFilterIsFalse;

// an error occurred while evaluating deterministic filter function with only constant
// inputs; thrown unless other filters eliminate all rows
@Nullable
private RuntimeException constantFilterError;

private int readPositions;

public OrcSelectiveRecordReader(
Expand Down Expand Up @@ -185,15 +211,12 @@ public OrcSelectiveRecordReader(
this.outputColumns = outputColumns.stream().map(zeroBasedIndices::get).collect(toImmutableList());
this.columnTypes = includedColumns.entrySet().stream().collect(toImmutableMap(entry -> zeroBasedIndices.get(entry.getKey()), Map.Entry::getValue));
this.filterFunctionWithoutInput = getFilterFunctionWithoutInputs(filterFunctions);
this.filterFunctionsWithInputs = filterFunctions.stream()
.filter(function -> function.getInputChannels().length > 0)
.collect(toImmutableList());
this.filterFunctionInputMapping = Maps.transformValues(filterFunctionInputMapping, zeroBasedIndices::get);
this.filterFunctionInputs = filterFunctions.stream()

Set<Integer> usedInputChannels = filterFunctions.stream()
.flatMapToInt(function -> Arrays.stream(function.getInputChannels()))
.boxed()
.map(this.filterFunctionInputMapping::get)
.collect(toImmutableSet());
this.filterFunctionInputMapping = Maps.transformValues(Maps.filterKeys(filterFunctionInputMapping, usedInputChannels::contains), zeroBasedIndices::get);
this.columnsWithFilterScores = filters
.entrySet()
.stream()
Expand Down Expand Up @@ -229,19 +252,124 @@ public OrcSelectiveRecordReader(
}
}

if (!evaluateDeterministicFilterFunctionsWithConstantInputs(filterFunctions)) {
constantFilterIsFalse = true;
// No further initialization needed.
return;
}

// Initial order of stream readers is:
// - readers with integer equality
// - readers with integer range / multivalues / inequality
// - readers with filters
// - followed by readers for columns that provide input to filter functions
// - followed by readers for columns that don't have any filtering
streamReaderOrder = orderStreamReaders(columnTypes.keySet().stream().filter(index -> this.constantValues[index] == null).collect(toImmutableSet()), columnsWithFilterScores, filterFunctionInputs, columnTypes);
streamReaderOrder = orderStreamReaders(columnTypes.keySet().stream().filter(index -> this.constantValues[index] == null).collect(toImmutableSet()), columnsWithFilterScores, this.filterFunctionInputMapping.keySet(), columnTypes);

List<FilterFunction> filterFunctionsWithInputs = filterFunctions.stream()
.filter(OrcSelectiveRecordReader::hasInputs)
.filter(not(this::allConstantInputs))
.collect(toImmutableList());

// figure out when to evaluate filter functions; a function is ready for evaluation as soon as the last input has been read
filterFunctionsOrder = orderFilterFunctionsWithInputs(streamReaderOrder, filterFunctionsWithInputs, this.filterFunctionInputMapping);
filterFunctionInputs = collectFilterFunctionInputs(filterFunctionsOrder, this.filterFunctionInputMapping);

filterFunctionsWithConstantInputs = filterFunctions.stream()
.filter(not(FilterFunction::isDeterministic))
.filter(OrcSelectiveRecordReader::hasInputs)
.filter(this::allConstantInputs)
.collect(toImmutableList());
filterFunctionConstantInputs = filterFunctionsWithConstantInputs.stream()
.flatMapToInt(function -> Arrays.stream(function.getInputChannels()))
.boxed()
.map(this.filterFunctionInputMapping::get)
.collect(toImmutableSet());
}

private boolean evaluateDeterministicFilterFunctionsWithConstantInputs(List<FilterFunction> filterFunctions)
{
    // Deterministic functions whose inputs are all constants can be evaluated once up front.
    // Short-circuits on the first failing function, same as an early return from a loop.
    return filterFunctions.stream()
            .filter(FilterFunction::isDeterministic)
            .filter(OrcSelectiveRecordReader::hasInputs)
            .filter(this::allConstantInputs)
            .allMatch(this::evaluateDeterministicFilterFunctionWithConstantInputs);
}

private boolean evaluateDeterministicFilterFunctionWithConstantInputs(FilterFunction function)
{
    // Build a single-row page where each input channel is an RLE block holding the
    // column's constant value (NULL_MARKER stands in for SQL NULL).
    int[] inputChannels = function.getInputChannels();
    Block[] inputBlocks = new Block[inputChannels.length];
    for (int channel = 0; channel < inputChannels.length; channel++) {
        int columnIndex = filterFunctionInputMapping.get(inputChannels[channel]);
        Object value = constantValues[columnIndex];
        inputBlocks[channel] = RunLengthEncodedBlock.create(columnTypes.get(columnIndex), value == NULL_MARKER ? null : value, 1);
    }

    initializeTmpErrors(1);
    int survivingPositions = function.filter(new Page(inputBlocks), new int[] {0}, 1, tmpErrors);

    // Remember any evaluation error; it is thrown later unless other filters drop all rows.
    if (tmpErrors[0] != null) {
        constantFilterError = tmpErrors[0];
    }
    return survivingPositions == 1;
}

private static boolean hasInputs(FilterFunction function)
{
    // A function with an empty input-channel list reads no columns.
    return function.getInputChannels().length != 0;
}

private boolean allConstantInputs(FilterFunction function)
{
    // True when every input channel maps to a column whose value is a known constant.
    for (int channel : function.getInputChannels()) {
        if (constantValues[filterFunctionInputMapping.get(channel)] == null) {
            return false;
        }
    }
    return true;
}

private static List<FilterFunction>[] orderFilterFunctionsWithInputs(int[] streamReaderOrder, List<FilterFunction> filterFunctions, Map<Integer, Integer> inputMapping)
{
    // Slot i holds the functions that become evaluable right after reader i has run,
    // i.e. the functions whose last-read input is the column at streamReaderOrder[i].
    List<FilterFunction>[] functionsByLastInput = new List[streamReaderOrder.length];
    for (FilterFunction function : filterFunctions) {
        int lastInputIndex = -1;
        for (int channel : function.getInputChannels()) {
            int columnIndex = inputMapping.get(channel);
            int readerIndex = Ints.indexOf(streamReaderOrder, columnIndex);
            if (readerIndex > lastInputIndex) {
                lastInputIndex = readerIndex;
            }
        }

        // Every input must correspond to some stream reader.
        verify(lastInputIndex >= 0);
        if (functionsByLastInput[lastInputIndex] == null) {
            functionsByLastInput[lastInputIndex] = new ArrayList<>();
        }
        functionsByLastInput[lastInputIndex].add(function);
    }

    return functionsByLastInput;
}

private static Set<Integer>[] collectFilterFunctionInputs(List<FilterFunction>[] functionsOrder, Map<Integer, Integer> inputMapping)
{
    // For each populated slot, gather the set of column indices its functions read.
    // Slots with no functions stay null, mirroring functionsOrder.
    Set<Integer>[] inputsBySlot = new Set[functionsOrder.length];
    for (int slot = 0; slot < functionsOrder.length; slot++) {
        List<FilterFunction> functions = functionsOrder[slot];
        if (functions == null) {
            continue;
        }
        inputsBySlot[slot] = functions.stream()
                .flatMapToInt(function -> Arrays.stream(function.getInputChannels()))
                .boxed()
                .map(inputMapping::get)
                .collect(toImmutableSet());
    }

    return inputsBySlot;
}

private static Optional<FilterFunction> getFilterFunctionWithoutInputs(List<FilterFunction> filterFunctions)
{
List<FilterFunction> functions = filterFunctions.stream()
.filter(function -> function.getInputChannels().length == 0)
.filter(not(OrcSelectiveRecordReader::hasInputs))
.collect(toImmutableList());
if (functions.isEmpty()) {
return Optional.empty();
Expand Down Expand Up @@ -429,20 +557,22 @@ public Page getNextPage()
positionsToRead = outputPositions;
}

boolean filterFunctionsApplied = filterFunctionsWithInputs.isEmpty();
int offset = getNextRowInGroup();
if (!filterFunctionsWithConstantInputs.isEmpty()) {
positionCount = applyFilterFunctions(filterFunctionsWithConstantInputs, filterFunctionConstantInputs, positionsToRead, positionCount);

for (int columnIndex : streamReaderOrder) {
if (!filterFunctionsApplied && !hasAnyFilter(columnIndex)) {
positionCount = applyFilterFunctions(positionsToRead, positionCount);
if (positionCount == 0) {
break;
}

positionsToRead = outputPositions;
filterFunctionsApplied = true;
if (positionCount == 0) {
batchRead(batchSize);
return EMPTY_PAGE;
}

positionsToRead = outputPositions;
}

int offset = getNextRowInGroup();

for (int i = 0; i < streamReaderOrder.length; i++) {
int columnIndex = streamReaderOrder[i];

if (!hasAnyFilter(columnIndex)) {
break;
}
Expand All @@ -455,11 +585,15 @@ public Page getNextPage()

positionsToRead = streamReader.getReadPositions();
verify(positionCount == 1 || positionsToRead[positionCount - 1] - positionsToRead[0] >= positionCount - 1, "positions must monotonically increase");
}

if (positionCount > 0 && !filterFunctionsApplied) {
positionCount = applyFilterFunctions(positionsToRead, positionCount);
positionsToRead = outputPositions;
if (filterFunctionsOrder[i] != null) {
positionCount = applyFilterFunctions(filterFunctionsOrder[i], filterFunctionInputs[i], positionsToRead, positionCount);
if (positionCount == 0) {
break;
}

positionsToRead = outputPositions;
}
}

batchRead(batchSize);
Expand All @@ -468,11 +602,13 @@ public Page getNextPage()
return EMPTY_PAGE;
}

if (filterFunctionsWithInputs.isEmpty() && filterFunctionWithoutInput.isPresent()) {
for (int i = 0; i < positionCount; i++) {
if (errors[positionsToRead[i]] != null) {
throw errors[positionsToRead[i]];
}
if (constantFilterError != null) {
throw constantFilterError;
}

for (int i = 0; i < positionCount; i++) {
if (errors[positionsToRead[i]] != null) {
throw errors[positionsToRead[i]];
}
}

Expand Down Expand Up @@ -517,7 +653,7 @@ private SelectiveStreamReader getStreamReader(int columnIndex)

private boolean hasAnyFilter(int columnIndex)
{
return columnsWithFilterScores.containsKey(columnIndex) || filterFunctionInputs.contains(columnIndex);
return columnsWithFilterScores.containsKey(columnIndex) || filterFunctionInputMapping.containsKey(columnIndex);
}

private void initializePositions(int batchSize)
Expand All @@ -544,7 +680,7 @@ private int applyFilterFunctionWithNoInputs(int positionCount)
return filterFunctionWithoutInput.get().filter(page, outputPositions, positionCount, errors);
}

private int applyFilterFunctions(int[] positions, int positionCount)
private int applyFilterFunctions(List<FilterFunction> filterFunctions, Set<Integer> filterFunctionInputs, int[] positions, int positionCount)
{
BlockLease[] blockLeases = new BlockLease[hiveColumnIndices.length];
Block[] blocks = new Block[hiveColumnIndices.length];
Expand All @@ -562,16 +698,17 @@ private int applyFilterFunctions(int[] positions, int positionCount)
}
}

if (filterFunctionWithoutInput.isPresent()) {
for (int i = 0; i < positionCount; i++) {
errors[i] = errors[positions[i]];
}
initializeTmpErrors(positionCount);
for (int i = 0; i < positionCount; i++) {
tmpErrors[i] = errors[positions[i]];
}

Arrays.fill(errors, null);

try {
initializeOutputPositions(positionCount);

for (FilterFunction function : filterFunctionsWithInputs) {
for (FilterFunction function : filterFunctions) {
int[] inputs = function.getInputChannels();
Block[] inputBlocks = new Block[inputs.length];

Expand All @@ -580,23 +717,18 @@ private int applyFilterFunctions(int[] positions, int positionCount)
}

Page page = new Page(positionCount, inputBlocks);
positionCount = function.filter(page, outputPositions, positionCount, errors);
positionCount = function.filter(page, outputPositions, positionCount, tmpErrors);
if (positionCount == 0) {
break;
}
}

for (int i = 0; i < positionCount; i++) {
if (errors[i] != null) {
throw errors[i];
}
}

// at this point outputPositions are relative to page, e.g. they are indices into positions array
// translate outputPositions to positions relative to the start of the row group,
// e.g. make outputPositions a subset of positions array
for (int i = 0; i < positionCount; i++) {
outputPositions[i] = positions[outputPositions[i]];
errors[outputPositions[i]] = tmpErrors[i];
}
return positionCount;
}
Expand All @@ -609,6 +741,16 @@ private int applyFilterFunctions(int[] positions, int positionCount)
}
}

private void initializeTmpErrors(int positionCount)
{
    // Reuse the scratch array when it is large enough, clearing stale entries;
    // otherwise allocate a fresh (already-null) array.
    if (tmpErrors != null && tmpErrors.length >= positionCount) {
        Arrays.fill(tmpErrors, null);
        return;
    }
    tmpErrors = new RuntimeException[positionCount];
}

private void initializeOutputPositions(int positionCount)
{
if (outputPositions == null || outputPositions.length < positionCount) {
Expand Down

0 comments on commit ac93bd6

Please sign in to comment.