Skip to content

Commit

Permalink
Improve IS NULL pushdown for ClickHouse complex expression
Browse files Browse the repository at this point in the history
  • Loading branch information
ssheikin authored and Praveen2112 committed Oct 17, 2024
1 parent 9600eba commit 5936d81
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,9 @@ public ClickHouseClient(
.add(new RewriteStringComparison())
.add(new RewriteStringIn())
.add(new RewriteLike())
.map("$not($is_null(value))").to("value IS NOT NULL")
.map("$not(value: boolean)").to("NOT value")
.map("$is_null(value)").to("value IS NULL")
.build();
this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>(
this.connectorExpressionRewriter,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.clickhouse;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.query.QueryAssertions;
import io.trino.sql.query.QueryAssertions.QueryAssert;
import io.trino.testing.QueryRunner;
import io.trino.testing.datatype.ColumnSetup;
import io.trino.testing.datatype.DataSetup;
import io.trino.testing.sql.TemporaryRelation;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;
import static org.assertj.core.api.Assertions.assertThat;

public final class NullPushdownDataTypeTest
{
private final List<TestCase> testCases = new ArrayList<>();
private TestCase specialColumn;
private final boolean connectorExpressionOnly;

private NullPushdownDataTypeTest(boolean connectorExpressionOnly)
{
this.connectorExpressionOnly = connectorExpressionOnly;
}

public static NullPushdownDataTypeTest create()
{
return new NullPushdownDataTypeTest(false);
}

public static NullPushdownDataTypeTest connectorExpressionOnly()
{
return new NullPushdownDataTypeTest(true);
}

public NullPushdownDataTypeTest addSpecialColumn(String inputType, String inputLiteral, String expectedLiteral)
{
checkState(specialColumn == null, "Special column already set");
checkArgument(!"NULL".equalsIgnoreCase(inputLiteral), "Special column should not be NULL");
specialColumn = new TestCase(inputType, inputLiteral, Optional.of(expectedLiteral));
return this;
}

public NullPushdownDataTypeTest addTestCase(String inputType)
{
testCases.add(new TestCase(inputType, "NULL", Optional.empty()));
return this;
}

public NullPushdownDataTypeTest execute(QueryRunner queryRunner, DataSetup dataSetup)
{
return execute(queryRunner, queryRunner.getDefaultSession(), dataSetup);
}

public NullPushdownDataTypeTest execute(QueryRunner queryRunner, Session session, DataSetup dataSetup)
{
checkState(specialColumn != null, "Null pushdown test requires special column");
checkState(!testCases.isEmpty(), "No test cases");
List<ColumnSetup> columns = ImmutableList.<ColumnSetup>builder()
.add(specialColumn)
.addAll(testCases)
.build();
try (TemporaryRelation temporaryRelation = dataSetup.setupTemporaryRelation(columns)) {
verifyPredicate(queryRunner, session, temporaryRelation, true, false, !connectorExpressionOnly);
verifyPredicate(queryRunner, session, temporaryRelation, true, true, true);
verifyPredicate(queryRunner, session, temporaryRelation, false, false, !connectorExpressionOnly);
verifyPredicate(queryRunner, session, temporaryRelation, false, true, true);
}
return this;
}

private void verifyPredicate(QueryRunner queryRunner, Session session, TemporaryRelation temporaryRelation, boolean isNull, boolean connectorExpression, boolean expectPushdown)
{
String specialColumnName = "col_0";
String withConnectorExpression = connectorExpression ? " OR %s IS NULL".formatted(specialColumnName) : "";

String queryWithAll = "SELECT " + specialColumnName + " FROM " + temporaryRelation.getName() + " WHERE " +
IntStream.range(0, testCases.size())
.mapToObj(column -> getPredicate(column, isNull))
.collect(joining(" AND "))
+ withConnectorExpression;

// Closing QueryAssertions would close the QueryRunner
QueryAssertions queryAssertions = new QueryAssertions(queryRunner);
try {
assertPushdown(expectPushdown,
assertResult(isNull ? specialColumn.expectedLiteral() : Optional.empty(),
assertThat(queryAssertions.query(session, queryWithAll))));
}
catch (AssertionError e) {
// if failed - identify exact column which caused the failure
for (int column = 0; column < testCases.size(); column++) {
String queryWithSingleColumnPredicate = "SELECT " + specialColumnName + " FROM " + temporaryRelation.getName() + " WHERE " + getPredicate(column, isNull) + withConnectorExpression;
assertPushdown(expectPushdown,
assertResult(isNull ? specialColumn.expectedLiteral() : Optional.empty(),
assertThat(queryAssertions.query(session, queryWithSingleColumnPredicate))));
}
throw new IllegalStateException("Single column assertion should fail for at least one column, if query of all column failed", e);
}
}

private static String getPredicate(int column, boolean isNull)
{
String columnName = "col_" + (1 + column);
return isNull
? columnName + " IS NULL"
: columnName + " IS NOT NULL";
}

private static QueryAssert assertResult(Optional<String> value, QueryAssert assertion)
{
return value.isPresent()
? assertion.matches("VALUES %s".formatted(value.get()))
: assertion.returnsEmptyResult();
}

private static QueryAssert assertPushdown(boolean expectPushdown, QueryAssert assertion)
{
return expectPushdown
? assertion.isFullyPushedDown()
: assertion.isNotFullyPushedDown(FilterNode.class);
}

private record TestCase(
String declaredType,
String inputLiteral,
Optional<String> expectedLiteral)
implements ColumnSetup
{
private TestCase
{
requireNonNull(declaredType, "declaredType is null");
requireNonNull(inputLiteral, "inputLiteral is null");
requireNonNull(expectedLiteral, "expectedLiteral is null");
}

@Override
public Optional<String> getDeclaredType()
{
return Optional.of(declaredType);
}

@Override
public String getInputLiteral()
{
return inputLiteral;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import io.trino.testing.MaterializedResult;
import io.trino.testing.QueryRunner;
import io.trino.testing.TestingConnectorBehavior;
import io.trino.testing.datatype.CreateAndInsertDataSetup;
import io.trino.testing.datatype.DataSetup;
import io.trino.testing.sql.SqlExecutor;
import io.trino.testing.sql.TestTable;
import org.junit.jupiter.api.Disabled;
Expand All @@ -37,6 +39,7 @@
import java.util.Optional;
import java.util.OptionalInt;

import static io.trino.plugin.clickhouse.ClickHouseSessionProperties.MAP_STRING_AS_VARCHAR;
import static io.trino.plugin.clickhouse.ClickHouseTableProperties.ENGINE_PROPERTY;
import static io.trino.plugin.clickhouse.ClickHouseTableProperties.ORDER_BY_PROPERTY;
import static io.trino.plugin.clickhouse.ClickHouseTableProperties.PARTITION_BY_PROPERTY;
Expand Down Expand Up @@ -1086,6 +1089,61 @@ private void assertLike(boolean isPositive, TestTable table, String withConnecto
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '%$_%' ESCAPE '$'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
}

@Test
public void testIsNull()
{
Session mapStringAsVarbinary = Session.builder(getSession())
.setCatalogSessionProperty("clickhouse", MAP_STRING_AS_VARCHAR, "false")
.build();

NullPushdownDataTypeTest.connectorExpressionOnly()
.addSpecialColumn("String", "'z'", "CAST('z' AS varchar)")
.addTestCase("Nullable(real)")
.addTestCase("Nullable(decimal(3, 1))")
.addTestCase("Nullable(decimal(30, 5))")
.execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_is_null"));

NullPushdownDataTypeTest.connectorExpressionOnly()
.addSpecialColumn("String", "'z'", "CAST('z' AS varbinary)")
.addTestCase("Nullable(char(10))")
.addTestCase("LowCardinality(Nullable(char(10)))")
.addTestCase("Nullable(FixedString(10))")
.addTestCase("LowCardinality(Nullable(FixedString(10)))")
.addTestCase("Nullable(varchar(30))")
.addTestCase("LowCardinality(Nullable(varchar(30)))")
.addTestCase("Nullable(String)")
.addTestCase("LowCardinality(Nullable(String))")
.execute(getQueryRunner(), mapStringAsVarbinary, clickhouseCreateAndInsert("tpch.test_is_null"));

NullPushdownDataTypeTest.create()
.addSpecialColumn("String", "'z'", "CAST('z' AS varchar)")
.addTestCase("Nullable(tinyint)")
.addTestCase("Nullable(smallint)")
.addTestCase("Nullable(integer)")
.addTestCase("Nullable(bigint)")
.addTestCase("Nullable(UInt8)")
.addTestCase("Nullable(UInt16)")
.addTestCase("Nullable(UInt32)")
.addTestCase("Nullable(UInt64)")
.addTestCase("Nullable(double)")
.addTestCase("Nullable(char(10))")
.addTestCase("LowCardinality(Nullable(char(10)))")
.addTestCase("Nullable(FixedString(10))")
.addTestCase("LowCardinality(Nullable(FixedString(10)))")
.addTestCase("Nullable(varchar(30))")
.addTestCase("LowCardinality(Nullable(varchar(30)))")
.addTestCase("Nullable(String)")
.addTestCase("LowCardinality(Nullable(String))")
.addTestCase("Nullable(date)")
.addTestCase("Nullable(timestamp)")
.addTestCase("Nullable(datetime)")
.addTestCase("Nullable(datetime('UTC'))")
.addTestCase("Nullable(UUID)")
.addTestCase("Nullable(IPv4)")
.addTestCase("Nullable(IPv6)")
.execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_is_null"));
}

@Test
@Override // Override because ClickHouse doesn't follow SQL standard syntax
public void testExecuteProcedure()
Expand Down Expand Up @@ -1154,4 +1212,9 @@ private Map<String, String> getTableProperties(String schemaName, String tableNa
return properties.buildOrThrow();
}
}

private DataSetup clickhouseCreateAndInsert(String tableNamePrefix)
{
return new CreateAndInsertDataSetup(new ClickHouseSqlExecutor(onRemoteDatabase()), tableNamePrefix);
}
}

0 comments on commit 5936d81

Please sign in to comment.