Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IGNITE-23183 #4381

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,39 @@ public void testMergeNullCols() {
.check();
}

@Test
public void testMergeWithSubqueryExpression() {
sql("CREATE TABLE t0(ID INT PRIMARY KEY, VAL INT)");
sql("CREATE TABLE t1(ID INT PRIMARY KEY, VAL BIGINT)");

String sql = "MERGE INTO t0 USING t1 ON t0.id = t1.id "
+ "WHEN MATCHED THEN UPDATE SET val = (SELECT val FROM t1 WHERE id > ?)";

sql("INSERT INTO t0 VALUES (1, 0), (2, 0)");
sql("INSERT INTO t1 VALUES (1, -100), (3, 3)");

// sub-query returns no rows.
sql(sql, 3);
assertQuery("SELECT * FROM t0 ORDER BY id")
.returns(1, null)
.returns(2, 0)
.check();

// sub-query returns single row.
sql(sql, 1);
assertQuery("SELECT * FROM t0 ORDER BY id")
.returns(1, 3)
.returns(2, 0)
.check();

// sub-query returns more than one row.
assertThrowsSqlException(
Sql.RUNTIME_ERR,
"Subquery returned more than 1 value",
() -> sql(sql, 0)
);
}

/**
* Test verifies that scan is executed within provided transaction.
*/
Expand Down Expand Up @@ -685,6 +718,37 @@ public void testUpdateAllowsDefault() {
}
}

@Test
public void testUpdateWithSubqueryExpression() {
sql("CREATE TABLE t0(ID INT PRIMARY KEY, VAL INT)");
sql("CREATE TABLE t1(ID INT PRIMARY KEY, VAL BIGINT)");

sql("INSERT INTO t0 VALUES (1, 1), (2, 2)");
sql("INSERT INTO t1 VALUES (1, 1), (2, 2)");

// Sub-query returns no rows.
sql("UPDATE t0 SET val = (SELECT val FROM t1 WHERE id = -42)");
assertQuery("SELECT * FROM t0")
.returns(1, null)
.returns(2, null)
.check();

// Sub-query returns single row.
sql("UPDATE t0 SET val = (SELECT val FROM t1 WHERE id = 2)");
assertQuery("SELECT * FROM t0")
.returns(1, 2)
.returns(2, 2)
.check();

// Sub-query returns more than one row.
//noinspection ThrowableNotThrown
assertThrowsSqlException(
Sql.RUNTIME_ERR,
"Subquery returned more than 1 value",
() -> sql("UPDATE t0 SET val = (SELECT val FROM t1)")
);
}

@Test
public void testDropDefault() {
// SQL Standard 2016 feature F221 - Explicit defaults
Expand Down Expand Up @@ -848,6 +912,35 @@ public void testInsertValueOverflow(String type, long max, long min) {
}
}

@Test
public void testInsertValueWithSubqueryExpression() {
sql("CREATE TABLE t0(ID INT PRIMARY KEY, VAL INT)");
sql("CREATE TABLE t1(ID INT PRIMARY KEY, VAL INT)");

sql("INSERT INTO t1 VALUES (1, 1), (2, 2)");

// Sub-query returns no rows.
sql("INSERT INTO t0 VALUES (1, (SELECT val FROM t1 WHERE id = -42))");
assertQuery("SELECT * FROM t0")
.returns(1, null)
.check();

// Sub-query returns single row.
sql("INSERT INTO t0 VALUES (2, (SELECT val FROM t1 WHERE id = 2))");
assertQuery("SELECT * FROM t0")
.returns(1, null)
.returns(2, 2)
.check();

// Sub-query returns more than one row.
//noinspection ThrowableNotThrown
assertThrowsSqlException(
Sql.RUNTIME_ERR,
"Subquery returned more than 1 value",
() -> sql("INSERT INTO t0 VALUES (2, (SELECT val FROM t1))")
);
}

@ParameterizedTest
@CsvSource(value = {
"id1,id2; id1",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.apache.calcite.runtime.Resources;
import org.apache.calcite.schema.impl.ModifiableViewTable;
import org.apache.calcite.sql.JoinConditionType;
import org.apache.calcite.sql.JoinType;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCall;
Expand Down Expand Up @@ -80,6 +81,7 @@
import org.apache.calcite.sql.util.SqlShuttle;
import org.apache.calcite.sql.validate.AliasNamespace;
import org.apache.calcite.sql.validate.SelectScope;
import org.apache.calcite.sql.validate.SqlNonNullableAccessors;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorException;
import org.apache.calcite.sql.validate.SqlValidatorImpl;
Expand Down Expand Up @@ -394,9 +396,16 @@ private static void syncSelectList(SqlSelect select, SqlUpdate update) {
int startPosition = selectList.size() - sourceExprListSize;

for (var i = 0; i < sourceExprListSize; i++) {
SqlNode sourceExpr = sourceExpressionList.get(i);
SqlNode replacement = sourceExpressionList.get(i);
int position = startPosition + i;
selectList.set(position, sourceExpr);

// This method was introduced to replace an expression with an expression that has the
// required type cast. Therefore, this only applies when the replacement contains SqlBasicCall.
// For example a call with SCALAR_QUERY is only present in sourceSelect, keeping original
// SqlSelect in sourceExpressionList, and we should not make a replacement in this case.
if (replacement instanceof SqlBasicCall) {
selectList.set(position, replacement);
}
}
}

Expand Down Expand Up @@ -460,10 +469,28 @@ protected SqlSelect createSourceSelectForUpdate(SqlUpdate call) {
.map(name -> alias.plus(name, SqlParserPos.ZERO))
.forEach(selectList::add);

int ordinal = 0;
// Force unique aliases to avoid a duplicate for Y with SET X=Y
for (SqlNode exp : call.getSourceExpressionList()) {
selectList.add(SqlValidatorUtil.addAlias(exp, SqlUtil.deriveAliasFromOrdinal(ordinal++)));
final SqlNodeList selectList2 = new SqlNodeList(SqlParserPos.ZERO);

igniteTable.rowTypeForUpdate((IgniteTypeFactory) typeFactory)
.getFieldNames().stream()
// .map(name -> alias.plus(name, SqlParserPos.ZERO))
.map(name -> new SqlIdentifier(name, SqlParserPos.ZERO)) // new SqlIdentifier(name, SqlParserPos.ZERO))
.forEach(selectList2::add);

for (int i = 0; i < call.getSourceExpressionList().size(); i++) {
SqlNode exp = call.getSourceExpressionList().get(i);

String alias0 = SqlUtil.deriveAliasFromOrdinal(i);

SqlIdentifier id = new SqlIdentifier(alias0, SqlParserPos.ZERO);

if (exp instanceof SqlSelect) {
call.getSourceExpressionList().set(i, id);
}

selectList2.add(id); // blahAlias.plus(id.getSimple(), SqlParserPos.ZERO));

selectList.add(SqlValidatorUtil.addAlias(exp, alias0));
}

SqlNode sourceTable = call.getTargetTable();
Expand All @@ -475,8 +502,13 @@ protected SqlSelect createSourceSelectForUpdate(SqlUpdate call) {
call.getAlias().getSimple());
}

return new SqlSelect(SqlParserPos.ZERO, null, selectList, sourceTable,
SqlSelect select1 = new SqlSelect(SqlParserPos.ZERO, null, selectList, sourceTable,
call.getCondition(), null, null, null, null, null, null, null, null);

SqlSelect select2 = new SqlSelect(SqlParserPos.ZERO, null, selectList2, select1,
null, null, null, null, null, null, null, null, null);

return select2;
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -767,7 +799,90 @@ protected SqlNode performUnconditionalRewrites(SqlNode node, boolean underFrom)
}
}

return super.performUnconditionalRewrites(node, underFrom);
SqlNode resNode = super.performUnconditionalRewrites(node, underFrom);

if (resNode instanceof SqlMerge) {
rewriteMergeAgain((SqlMerge) resNode);
// perform additional rewrites.
}

return resNode;
}

private static void rewriteMergeAgain(SqlMerge call) {
SqlNodeList selectList;
SqlUpdate updateStmt = call.getUpdateCall();
if (updateStmt != null) {
// if we have an update statement, just clone the select list
// from the update statement's source since it's the same as
// what we want for the select list of the merge source -- '*'
// followed by the update set expressions
SqlSelect sourceSelect = SqlNonNullableAccessors.getSourceSelect(updateStmt);

selectList = SqlNode.clone(
((SqlSelect) sourceSelect.getFrom()).getSelectList()
);
} else {
// otherwise, just use select *
selectList = new SqlNodeList(SqlParserPos.ZERO);
selectList.add(SqlIdentifier.star(SqlParserPos.ZERO));
}
SqlNode targetTable = call.getTargetTable();
if (call.getAlias() != null) {
targetTable =
SqlValidatorUtil.addAlias(
targetTable,
call.getAlias().getSimple());
}

// Provided there is an insert substatement, the source select for
// the merge is a left outer join between the source in the USING
// clause and the target table; otherwise, the join is just an
// inner join. Need to clone the source table reference in order
// for validation to work
SqlNode sourceTableRef = call.getSourceTableRef();
SqlInsert insertCall = call.getInsertCall();
JoinType joinType = (insertCall == null) ? JoinType.INNER : JoinType.LEFT;
final SqlNode leftJoinTerm = SqlNode.clone(sourceTableRef);
SqlNode outerJoin =
new SqlJoin(SqlParserPos.ZERO,
leftJoinTerm,
SqlLiteral.createBoolean(false, SqlParserPos.ZERO),
joinType.symbol(SqlParserPos.ZERO),
targetTable,
JoinConditionType.ON.symbol(SqlParserPos.ZERO),
call.getCondition());
SqlSelect select =
new SqlSelect(SqlParserPos.ZERO, null, selectList, outerJoin, null,
null, null, null, null, null, null, null, null);
call.setSourceSelect(select);

// Source for the insert call is a select of the source table
// reference with the select list being the value expressions;
// note that the values clause has already been converted to a
// select on the values row constructor; so we need to extract
// that via the from clause on the select
if (insertCall != null) {
SqlCall valuesCall = (SqlCall) insertCall.getSource();
SqlNode rowCallNode = valuesCall.operand(0);

if (rowCallNode instanceof SqlNodeList) {
// already rewritten.
return;
}

SqlCall rowCall = (SqlCall) rowCallNode;

selectList =
new SqlNodeList(
rowCall.getOperandList(),
SqlParserPos.ZERO);
final SqlNode insertSource = SqlNode.clone(sourceTableRef);
select =
new SqlSelect(SqlParserPos.ZERO, null, selectList, insertSource, null,
null, null, null, null, null, null, null, null);
insertCall.setSource(select);
}
}

/** Rewrites JOIN clause if required. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.function.UnaryOperator;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.sql.fun.SqlSingleValueAggFunction;
import org.apache.ignite.internal.sql.engine.framework.TestBuilders.TableBuilder;
import org.apache.ignite.internal.sql.engine.rel.IgniteAggregate;
import org.apache.ignite.internal.sql.engine.rel.agg.IgniteReduceAggregateBase;
Expand Down Expand Up @@ -917,6 +918,36 @@ enum TestCase {
*/
CASE_26A("SELECT val0, val1, COUNT(*) cnt FROM test GROUP BY val0, val1 ORDER BY val1 DESC",
schema(hash(0))),

/**
* Query: SELECT val0 FROM test WHERE val0 = (SELECT val1 FROM test).
*
* <p>Distribution single()
*/
CASE_27("SELECT val0 FROM test WHERE val0 = (SELECT val1 FROM test)", schema(single())),

/**
* Query: INSERT INTO test (id, val0) VALUES (1, (SELECT val1 FROM test)).
*
* <p>Distribution single()
*/
CASE_27A("INSERT INTO test (id, val0) VALUES (1, (SELECT val1 FROM test))", schema(single())),

/**
* Query: UPDATE test set val0 = (SELECT val1 FROM test).
*
* <p>Distribution single()
*/
CASE_27B("UPDATE test set val0 = (SELECT val1 FROM test)", schema(single())),

/**
* Query: MERGE INTO test as t0 USING test as t1 ON t0.id = t1.id
* WHEN MATCHED THEN UPDATE SET val1 = (SELECT val0 FROM test)
*
* <p>Distribution single()
*/
CASE_27C("MERGE INTO test as t0 USING test as t1 ON t0.id = t1.id "
+ "WHEN MATCHED THEN UPDATE SET val1 = (SELECT val0 FROM test)", schema(single())),
;

final String query;
Expand Down Expand Up @@ -1030,6 +1061,12 @@ <T extends RelNode> Predicate<T> hasAggregate() {
return mapNode.or(reduceNode);
}

<T extends RelNode> Predicate<T> hasSingleValueAggregate() {
return (Predicate<T>) isInstanceOf(IgniteAggregate.class)
.and(n -> n.getAggCallList().stream()
.anyMatch(agg -> agg.getAggregation() instanceof SqlSingleValueAggFunction));
}

<T extends RelNode> Predicate<T> hasDistinctAggregate() {
Predicate<T> mapNode = (Predicate<T>) isInstanceOf(IgniteAggregate.class)
.and(n -> n.getAggCallList().stream().anyMatch(NON_NULL_PREDICATE.and(AggregateCall::isDistinct)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,17 @@ public void subqueryWithAggregateInWhereClause() throws Exception {
checkSimpleAggHash(TestCase.CASE_14B);
}

/**
* Validates that the SINGLE_VALUE aggregate is added for a sub-query where a single value is expected.
*/
@Test
public void subqueryWithSingleValueAggregate() throws Exception {
checkSimpleAggSingle(TestCase.CASE_27, hasSingleValueAggregate());
checkSimpleAggSingle(TestCase.CASE_27A, hasSingleValueAggregate());
checkSimpleAggSingle(TestCase.CASE_27B, hasSingleValueAggregate());
checkSimpleAggSingle(TestCase.CASE_27C, hasSingleValueAggregate());
}

/**
* Validates a plan for a query with DISTINCT aggregate in WHERE clause.
*/
Expand Down Expand Up @@ -569,9 +580,13 @@ public void groupsWithOrderBySubsetOfGroupColumnDescending() throws Exception {
}

private void checkSimpleAggSingle(TestCase testCase) throws Exception {
checkSimpleAggSingle(testCase, hasAggregate());
}

private void checkSimpleAggSingle(TestCase testCase, Predicate<IgniteColocatedHashAggregate> aggPredicate) throws Exception {
assertPlan(testCase,
nodeOrAnyChild(isInstanceOf(IgniteColocatedHashAggregate.class)
.and(hasAggregate())
.and(aggPredicate)
.and(input(isTableScan("TEST")))
)
);
Expand Down
Loading