Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support table functions in SQL #391

Merged
merged 8 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.datasqrl.calcite;

import static com.datasqrl.plan.validate.ScriptValidator.addError;
import static com.datasqrl.plan.validate.ScriptValidator.isSelfTable;
import static com.datasqrl.plan.validate.ScriptValidator.isVariable;

Expand All @@ -16,6 +17,7 @@
import com.datasqrl.calcite.visitor.SqlNodeVisitor;
import com.datasqrl.calcite.visitor.SqlRelationVisitor;
import com.datasqrl.canonicalizer.ReservedName;
import com.datasqrl.error.ErrorLabel;
import com.datasqrl.plan.hints.TopNHint.Type;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
Expand All @@ -42,6 +44,7 @@
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLateralOperator;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorTable;
Expand Down Expand Up @@ -375,6 +378,38 @@ public Result visitSetOperation(SqlCall node, Context context) {
operandResults.get(0).params);
}

/**
 * Handles a call whose operator is a {@code SqlCollectionTableOperator} (the case routed here by
 * {@code SqlNodeVisitor.accept}). The work is shared with the other single-operand table
 * wrappers and lives in {@code visitAugmentedTable}.
 *
 * @param node the wrapper call
 * @param context the current rewrite context, passed through unchanged
 * @return the result produced by {@code visitAugmentedTable}
 */
@Override
public Result visitCollectTableFunction(SqlCall node, Context context) {
  final Result rewritten = visitAugmentedTable(node, context);
  return rewritten;
}

/**
 * Handles a call whose operator is a {@code SqlLateralOperator} (the case routed here by
 * {@code SqlNodeVisitor.accept}) by delegating to {@code visitAugmentedTable}.
 *
 * @param node the LATERAL wrapper call
 * @param context the current rewrite context, passed through unchanged
 * @return the result produced by {@code visitAugmentedTable}
 */
@Override
public Result visitLateralFunction(SqlCall node, Context context) {
  final Result rewritten = visitAugmentedTable(node, context);
  return rewritten;
}

/**
 * Handles a call whose operator is a {@code SqlUnnestOperator} (the case routed here by
 * {@code SqlNodeVisitor.accept}) by delegating to {@code visitAugmentedTable}.
 *
 * @param node the UNNEST wrapper call
 * @param context the current rewrite context, passed through unchanged
 * @return the result produced by {@code visitAugmentedTable}
 */
@Override
public Result visitUnnestFunction(SqlCall node, Context context) {
  final Result rewritten = visitAugmentedTable(node, context);
  return rewritten;
}

/**
 * Rewrites a single-operand table wrapper call (collection table, LATERAL, UNNEST): the operand
 * is rewritten through {@code SqlNodeVisitor.accept} and the same operator is re-applied to the
 * rewritten operand.
 *
 * @param node the wrapper call; must have exactly one operand
 * @param context the current rewrite context, passed to the operand rewrite
 * @return the operand's result with its {@code sqlNode} replaced by the rebuilt wrapper call;
 *     path, pullup columns, table references, condition and params are carried over unchanged
 * @throws IllegalStateException if the call does not have exactly one operand
 */
public Result visitAugmentedTable(SqlCall node, Context context) {
  // Only the first operand is rewritten below; fail fast instead of silently dropping the rest
  // and emitting a truncated call.
  if (node.getOperandList().size() != 1) {
    throw new IllegalStateException("Expected a single table condition (LATERAL, UNNEST, ...)");
  }
  Result result = SqlNodeVisitor.accept(this, node.getOperandList().get(0), context);
  SqlCall call = node.getOperator().createCall(node.getParserPosition(), result.sqlNode);
  return new Result(call, result.currentPath, result.pullupColumns, result.tableReferences,
      result.condition, result.params);
}

/**
 * Passes a user-defined table function call through without rewriting it.
 *
 * @param node the table function call, returned as the result's SQL node
 * @param context the current rewrite context (not used here)
 * @return a result wrapping {@code node} with empty lists, no condition and the enclosing
 *     {@code parameters}
 */
@Override
public Result visitUserDefinedTableFunction(SqlCall node, Context context) {
  return new Result(
      node,
      List.of(),
      List.of(),
      List.of(),
      Optional.empty(),
      parameters);
}

/**
 * Fallback for calls that none of the more specific visit methods handles. Such calls are not
 * supported by this rewriter.
 *
 * @param node the unsupported call
 * @param context the current rewrite context (not used here)
 * @return never returns normally
 * @throws RuntimeException always, naming the offending operator
 */
@Override
public Result visitCall(SqlCall node, Context context) {
  // Name the operator so the failure can be traced back to the query construct that caused it.
  throw new RuntimeException("Unsupported call: " + node.getOperator().getName());
}

@Value
public static class PullupColumn {
String columnName;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package com.datasqrl.calcite.visitor;

import com.google.common.base.Preconditions;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLateralOperator;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlUnnestOperator;
import org.apache.calcite.sql.SqlUnresolvedFunction;
import org.apache.calcite.sql.SqrlAssignTimestamp;
import org.apache.calcite.sql.SqrlAssignment;
import org.apache.calcite.sql.SqrlDistinctQuery;
Expand All @@ -17,6 +21,8 @@
import org.apache.calcite.sql.SqrlSqlQuery;
import org.apache.calcite.sql.SqrlStreamQuery;
import org.apache.calcite.sql.StatementVisitor;
import org.apache.calcite.sql.fun.SqlCollectionTableOperator;
import org.apache.calcite.sql.validate.SqlUserDefinedTableFunction;

public abstract class SqlNodeVisitor<R, C> implements
SqlRelationVisitor<R, C>,
Expand Down Expand Up @@ -48,6 +54,7 @@ public static <R, C> R accept(StatementVisitor<R, C> visitor, SqlNode node, C co
}

public static <R, C> R accept(SqlRelationVisitor<R, C> visitor, SqlNode node, C context) {
Preconditions.checkNotNull(node, "Could not rewrite query.");
if (node.getKind() == SqlKind.AS) {
return visitor.visitAliasedRelation((SqlCall) node, context);
} else if (node instanceof SqlIdentifier) {
Expand All @@ -59,8 +66,20 @@ public static <R, C> R accept(SqlRelationVisitor<R, C> visitor, SqlNode node, C
} else if (node instanceof SqlCall
&& SqlKind.SET_QUERY.contains(node.getKind())) {
return visitor.visitSetOperation((SqlCall) node, context);
} else if (node instanceof SqlCall && ((SqlCall) node).getOperator() instanceof SqlCollectionTableOperator) {
return visitor.visitCollectTableFunction((SqlCall) node, context);
} else if (node instanceof SqlCall && ((SqlCall) node).getOperator() instanceof SqlLateralOperator) {
return visitor.visitLateralFunction((SqlCall) node, context);
} else if (node instanceof SqlCall && ((SqlCall) node).getOperator() instanceof SqlUnnestOperator) {
return visitor.visitUnnestFunction((SqlCall) node, context);
} else if (node instanceof SqlCall &&
((SqlCall) node).getOperator() instanceof SqlUserDefinedTableFunction) {
return visitor.visitUserDefinedTableFunction((SqlCall) node, context);
} else if (node instanceof SqlCall &&
((SqlCall) node).getOperator() instanceof SqlUnresolvedFunction) {
return visitor.visitUserDefinedTableFunction((SqlCall) node, context);
} else if (node instanceof SqlCall) {
return visitor.visitTableFunction((SqlCall) node, context);
return visitor.visitCall((SqlCall) node, context);
}
throw new RuntimeException("Unknown sql statement node:" + node);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ public interface SqlRelationVisitor<R, C> {
R visitTable(SqlIdentifier node, C context);
R visitJoin(SqlJoin node, C context);
R visitSetOperation(SqlCall node, C context);
default R visitTableFunction(SqlCall node, C context) {
return null;
}
/** Visits a call whose operator is a {@code SqlCollectionTableOperator}. */
R visitCollectTableFunction(SqlCall node, C context);
/** Visits a call whose operator is a {@code SqlLateralOperator}. */
R visitLateralFunction(SqlCall node, C context);
/** Visits a call whose operator is a {@code SqlUnnestOperator}. */
R visitUnnestFunction(SqlCall node, C context);
/**
 * Visits a table function call. {@code SqlNodeVisitor.accept} routes calls here when the
 * operator is a {@code SqlUserDefinedTableFunction} or a {@code SqlUnresolvedFunction}.
 */
R visitUserDefinedTableFunction(SqlCall node, C context);
/** Visits any other {@code SqlCall} that the more specific cases did not match. */
R visitCall(SqlCall node, C context);
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLateralOperator;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
Expand Down Expand Up @@ -469,6 +470,39 @@ public SqlNode visitSetOperation(SqlCall node, Object context) {
.map(o->SqlNodeVisitor.accept(this, o, context))
.collect(Collectors.toList()));
}

/**
 * Handles a call whose operator is a {@code SqlCollectionTableOperator} by delegating to
 * {@code visitAugmentedTable}.
 */
@Override
public SqlNode visitCollectTableFunction(SqlCall node, Object context) {
  final SqlNode rewritten = visitAugmentedTable(node, context);
  return rewritten;
}

/**
 * Handles a call whose operator is a {@code SqlLateralOperator} by delegating to
 * {@code visitAugmentedTable}.
 */
@Override
public SqlNode visitLateralFunction(SqlCall node, Object context) {
  final SqlNode rewritten = visitAugmentedTable(node, context);
  return rewritten;
}

/**
 * Handles a call whose operator is a {@code SqlUnnestOperator} by delegating to
 * {@code visitAugmentedTable}.
 */
@Override
public SqlNode visitUnnestFunction(SqlCall node, Object context) {
  final SqlNode rewritten = visitAugmentedTable(node, context);
  return rewritten;
}

/**
 * Rewrites the first operand of a table wrapper call through {@code SqlNodeVisitor.accept} and
 * re-applies the wrapper's operator to the rewritten operand.
 *
 * @param node the wrapper call (collection table, LATERAL or UNNEST)
 * @param context visitor context, passed through to the operand rewrite
 * @return a new call of the same operator at the same parser position
 */
private SqlNode visitAugmentedTable(SqlCall node, Object context) {
  final SqlNode firstOperand = node.getOperandList().get(0);
  final SqlNode rewrittenOperand = SqlNodeVisitor.accept(this, firstOperand, context);
  return node.getOperator().createCall(node.getParserPosition(), rewrittenOperand);
}

/**
 * Rebuilds a table function call after passing every operand through the visitor returned by
 * {@code rewriteVariables(parameterList, materializeSelf)}.
 *
 * @param node the table function call
 * @param context visitor context (not used here)
 * @return a new call of the same operator at the same parser position with the rewritten operands
 */
@Override
public SqlNode visitUserDefinedTableFunction(SqlCall node, Object context) {
  // A fresh rewriter is requested per operand, exactly as many times as there are operands.
  List<SqlNode> rewrittenOperands = node.getOperandList()
      .stream()
      .map(operand -> operand.accept(rewriteVariables(parameterList, materializeSelf)))
      .collect(Collectors.toList());
  return node.getOperator().createCall(node.getParserPosition(), rewrittenOperands);
}

/**
 * Fallback for calls not matched by a more specific visit method: reports the operator as
 * unsupported.
 *
 * @param node the unsupported call
 * @param context visitor context (not used here)
 * @return never returns normally
 */
@Override
public SqlNode visitCall(SqlCall node, Object context) {
  final String operatorName = node.getOperator().getName();
  throw addError(ErrorLabel.GENERIC, node, "Unsupported call: %s", operatorName);
}
}, query, null);

return Pair.of(parameterList, node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import static org.apache.calcite.sql.SqlUtil.stripAs;

import com.datasqrl.calcite.QueryPlanner;
import com.datasqrl.calcite.SqrlToSql;
import com.datasqrl.calcite.function.SqrlTableMacro;
import com.datasqrl.calcite.schema.PathWalker;
import com.datasqrl.calcite.schema.sql.SqlBuilders.SqlAliasCallBuilder;
Expand All @@ -18,6 +19,7 @@
import com.datasqrl.plan.validate.SqrlToValidatorSql.Context;
import com.datasqrl.plan.validate.SqrlToValidatorSql.Result;
import com.datasqrl.util.CalciteUtil.RelDataTypeFieldBuilder;
import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import java.util.ArrayList;
Expand All @@ -36,11 +38,15 @@
import org.apache.calcite.schema.TableFunction;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLateralOperator;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlSyntax;
import org.apache.calcite.sql.SqrlTableFunctionDef;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
Expand Down Expand Up @@ -77,18 +83,16 @@ public Result visitQuerySpecification(SqlSelect call, Context context) {
if (ident.isStar() && ident.names.size() == 1) {
for (List<String> path : newContext.getAliasPathMap().values()) {
Optional<SqlUserDefinedTableFunction> sqrlTable = planner.getTableFunction(path);
if (sqrlTable.isEmpty()) {
throw addError(errorCollector, ErrorLabel.GENERIC, node, "Could not find table %s",getDisplay(ident));
if (sqrlTable.isPresent()) {
isA.put(context.root, sqrlTable.get().getFunction());
}
isA.put(context.root, sqrlTable.get().getFunction());
}
} else if (ident.isStar() && ident.names.size() == 2) {
List<String> path = newContext.getAliasPath(ident.names.get(0));
Optional<SqlUserDefinedTableFunction> sqrlTable = planner.getTableFunction(path);
if (sqrlTable.isEmpty()) {
throw addError(errorCollector, ErrorLabel.GENERIC, node, "Could not find table %s",getDisplay(ident));
if (sqrlTable.isPresent()) {
isA.put(context.root, sqrlTable.get().getFunction());
}
isA.put(context.root, sqrlTable.get().getFunction());
}
}
}
Expand Down Expand Up @@ -248,6 +252,15 @@ public Result visitTable(SqlIdentifier node, Context context) {
return new Result(call1, pathWalker.getAbsolutePath(), plannerFns);
}

/**
 * Fallback for calls not matched by a more specific visit method. A LATERAL call is passed
 * through untouched with empty path and function lists; any other call is reported as an error.
 *
 * @param node the call to inspect
 * @param context validation context (not used here)
 * @return a result wrapping {@code node} when its operator is a {@code SqlLateralOperator}
 */
@Override
public Result visitCall(SqlCall node, Context context) {
  if (!(node.getOperator() instanceof SqlLateralOperator)) {
    throw addError(errorCollector, ErrorLabel.GENERIC, node, "Call not yet supported %s",
        node.getOperator().getName());
  }
  return new Result(node, List.of(), List.of());
}

/**
 * Renders an identifier for messages by joining its name components with dots.
 *
 * @param node the identifier to render
 * @return the components of {@code node.names} separated by {@code "."}
 */
private String getDisplay(SqlIdentifier node) {
  final List<String> components = node.names;
  return String.join(".", components);
}
Expand Down Expand Up @@ -291,6 +304,42 @@ public Result visitSetOperation(SqlCall node, Context context) {
List.of(), plannerFns);
}

/**
 * Handles a call whose operator is a {@code SqlCollectionTableOperator} by delegating to
 * {@code visitAugmentedTable}.
 */
@Override
public Result visitCollectTableFunction(SqlCall node, Context context) {
  final Result validated = visitAugmentedTable(node, context);
  return validated;
}

/**
 * Handles a call whose operator is a {@code SqlLateralOperator} by delegating to
 * {@code visitAugmentedTable}.
 */
@Override
public Result visitLateralFunction(SqlCall node, Context context) {
  final Result validated = visitAugmentedTable(node, context);
  return validated;
}

/**
 * Handles a call whose operator is a {@code SqlUnnestOperator} by delegating to
 * {@code visitAugmentedTable}.
 */
@Override
public Result visitUnnestFunction(SqlCall node, Context context) {
  final Result validated = visitAugmentedTable(node, context);
  return validated;
}

/**
 * Validates a single-operand table wrapper call (LATERAL, UNNEST, ...) by visiting its operand
 * and re-applying the wrapper's operator to the visited operand.
 *
 * @param node the wrapper call; must have exactly one operand
 * @param context validation context, passed through to the operand visit
 * @return the rebuilt call together with the operand's path and functions
 * @throws IllegalStateException if the call does not have exactly one operand
 */
private Result visitAugmentedTable(SqlCall node, Context context) {
  final List<SqlNode> operands = node.getOperandList();
  Preconditions.checkState(operands.size() == 1, "Expected a single table condition (LATERAL, UNNEST, ...)");
  final Result operandResult = SqlNodeVisitor.accept(this, node.getOperandList().get(0), context);
  // The function is not fully resolved here; only its existence is checked and the SQL validator
  // does the rest.
  final SqlCall rebuilt =
      node.getOperator().createCall(node.getParserPosition(), operandResult.sqlNode);
  return new Result(rebuilt, operandResult.currentPath, operandResult.fncs);
}

/**
 * Checks that a table function with the called name exists. The planner's operator table is
 * searched for overloads in the {@code USER_DEFINED_TABLE_FUNCTION} category with function
 * syntax, using the catalog reader's name matcher; the call itself is returned unchanged.
 *
 * @param node the table function call
 * @param context validation context (not used here)
 * @return a result wrapping {@code node} with empty path and function lists
 */
@Override
public Result visitUserDefinedTableFunction(SqlCall node, Context context) {
  final List<SqlOperator> overloads = new ArrayList<>();
  planner.getOperatorTable().lookupOperatorOverloads(
      node.getOperator().getNameAsId(),
      SqlFunctionCategory.USER_DEFINED_TABLE_FUNCTION,
      SqlSyntax.FUNCTION,
      overloads,
      planner.getCatalogReader().nameMatcher());

  if (!overloads.isEmpty()) {
    return new Result(node, List.of(), List.of());
  }
  // No overload registered under this name: report it against the offending call.
  throw addError(errorCollector, ErrorLabel.GENERIC, node, "Could not find table function %s",
      node.getOperator().getName());
}

@AllArgsConstructor
public class WalkSubqueries extends SqlShuttle {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ public void add(ResolvedExport export) {
public void addTable(RootSqrlTable root) {
removePrefix(this.pathToAbsolutePathMap.keySet(),root.getName().toNamePath());
plus().add(String.join(".", root.getPath().toStringList()) + "$"
+ sqrlFramework.getUniqueMacroInt().incrementAndGet(), root
);
+ sqrlFramework.getUniqueMacroInt().incrementAndGet(), root);
if (!root.getParameters().isEmpty()) {
plus().add(root.getName().getDisplay(), root);
}
}

private void removePrefix(Set<NamePath> set, NamePath prefix) {
Expand All @@ -107,6 +109,9 @@ public void addRelationship(Relationship relationship) {
this.sysTableToRelationshipMap.put(relationship.getFromTable(), relationship);
plus().add(String.join(".", relationship.getPath().toStringList()) + "$"
+ sqrlFramework.getUniqueMacroInt().incrementAndGet(), relationship);
if (!relationship.getParameters().isEmpty()) {
plus().add(String.join(".", relationship.getPath().toStringList()), relationship);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should call a helper function. We should already have one for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed in next commit

}
}

public void addTableMapping(NamePath path, String nameId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,49 @@ public void groupTest() {
+ " GROUP BY p._uuid;");
}

/**
 * Validates a script that defines two overloads of table function {@code X} (one and two
 * parameters) and calls each of them through {@code TABLE(...)} with literal arguments.
 */
@Test
public void callTableFunction() {
  final String script = "IMPORT ecommerce-data.Orders;\n"
      + "X(@id: Int) := SELECT * FROM Orders WHERE id = @id;\n"
      + "X(@id: Int, @customerid: Int) := SELECT * FROM Orders WHERE id = @id AND customerid = @customerid;\n"
      + "Y(@id: Int) := SELECT * FROM TABLE(X(2));\n"
      + "Z(@id: Int) := SELECT * FROM TABLE(X(2, 3));\n";
  validateScript(script);
}

/**
 * Validates a script in which table function {@code Y} forwards its parameter to {@code X}, and
 * {@code Z} in turn calls {@code Y} with a literal argument.
 */
@Test
public void chainedTableFncCallTest() {
  final String script = "IMPORT ecommerce-data.Orders;\n"
      + "X(@id: Int) := SELECT * FROM Orders WHERE id = @id;\n"
      + "Y(@id: Int) := SELECT * FROM TABLE(X(@id));\n"
      + "Z := SELECT * FROM TABLE(Y(3));\n";
  validateScript(script);
}

/**
 * Validates a script that declares a parameterized JOIN relationship on
 * {@code Orders.entries.product} and calls it by its back-quoted path through {@code TABLE(...)}.
 */
@Test
public void joinTableFncCallTest() {
  final String script = "IMPORT ecommerce-data.Orders;\n"
      + "IMPORT ecommerce-data.Product;\n"
      + "Orders.entries.product(@id: Int) := JOIN Product p ON p.productid = @id;\n"
      + "Y(@id: Int) := SELECT * FROM TABLE(`Orders.entries.product`(@id));";
  validateScript(script);
}
/**
 * Same script as {@code joinTableFncCallTest}, except that the parameterized relationship is
 * declared twice. Currently disabled.
 */
@Test
@Disabled
// TODO: fails with an "Illegal use of dynamic param" error.
// NOTE(review): the relationship definition appears twice in the script; presumably that
// redefinition is what this test targets - confirm.
public void joinTableFncCall2Test() {
  final String script = "IMPORT ecommerce-data.Orders;\n"
      + "IMPORT ecommerce-data.Product;\n"
      + "Orders.entries.product(@id: Int) := JOIN Product p ON p.productid = @id;\n"
      + "Orders.entries.product(@id: Int) := JOIN Product p ON p.productid = @id;\n"
      + "Y(@id: Int) := SELECT * FROM TABLE(`Orders.entries.product`(@id));";
  validateScript(script);
}

/**
 * Validates a script that joins the result of {@code TABLE(X(2))} with a
 * {@code LATERAL TABLE(...)} call whose first argument refers to a column of the left side.
 */
@Test
public void lateralJoinTest() {
  final String script = "IMPORT ecommerce-data.Orders;\n"
      + "X(@id: Int) := SELECT * FROM Orders WHERE id = @id;\n"
      + "X(@id: Int, @customerid: Int) := SELECT * FROM Orders WHERE id = @id AND customerid = @customerid;\n"
      + "Y(@id: Int) := SELECT * FROM TABLE(X(2)) AS t JOIN LATERAL TABLE(X(t.id, 3));\n";
  validateScript(script);
}

@Test
public void orderTest() {
validateScript("IMPORT ecommerce-data.Orders;"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
>>>orders$2
LogicalTableScan(table=[[orders$1]])

>>>x$1
LogicalProject(_uuid=[$0], _ingest_time=[$1], id=[$2], customerid=[$3], time=[$4], entries=[$5])
LogicalFilter(condition=[=($2, ?0)])
LogicalTableScan(table=[[orders$2]])

>>>x$2
LogicalProject(_uuid=[$0], _ingest_time=[$1], id=[$2], customerid=[$3], time=[$4], entries=[$5])
LogicalFilter(condition=[AND(=($2, ?0), =($3, ?1))])
LogicalTableScan(table=[[orders$2]])

>>>y$1
LogicalProject(_uuid=[$0], _ingest_time=[$1], id=[$2], customerid=[$3], time=[$4], entries=[$5])
LogicalFilter(condition=[=($2, 2)])
LogicalTableScan(table=[[orders$2]])

>>>z$1
LogicalProject(_uuid=[$0], _ingest_time=[$1], id=[$2], customerid=[$3], time=[$4], entries=[$5])
LogicalFilter(condition=[AND(=($2, 2), =($3, 3))])
LogicalTableScan(table=[[orders$2]])

Loading