Skip to content

Commit

Permalink
Spark 3.3: Iceberg parser should passthrough unsupported procedure to…
Browse files Browse the repository at this point in the history
… delegate (apache#11580)
  • Loading branch information
pan3793 authored Nov 18, 2024
1 parent 209781a commit 568940f
Show file tree
Hide file tree
Showing 13 changed files with 149 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.iceberg.common.DynConstructors
import org.apache.iceberg.spark.ExtendedParser
import org.apache.iceberg.spark.ExtendedParser.RawOrderField
import org.apache.iceberg.spark.Spark3Util
import org.apache.iceberg.spark.procedures.SparkProcedures
import org.apache.iceberg.spark.source.SparkTable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.SparkSession
Expand Down Expand Up @@ -194,8 +195,10 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserI
// Strip comments of the form /* ... */. This must come after stripping newlines so that
// comments that span multiple lines are caught.
.replaceAll("/\\*.*?\\*/", " ")
// Strip backticks so that `system`.`ancestors_of` becomes system.ancestors_of
.replaceAll("`", "")
.trim()
normalized.startsWith("call") || (
isIcebergProcedure(normalized) || (
normalized.startsWith("alter table") && (
normalized.contains("add partition field") ||
normalized.contains("drop partition field") ||
Expand All @@ -209,6 +212,12 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserI
isSnapshotRefDdl(normalized)))
}

/**
 * Returns true when the statement is a CALL that targets a built-in Iceberg procedure.
 *
 * All built-in Iceberg procedures are under the 'system' namespace, so a CALL qualifies when
 * the text in front of its argument list contains `system.<name>` for one of the names
 * returned by `SparkProcedures.names()`. Every other CALL yields false.
 *
 * @param normalized the statement text prepared by the caller, with comments and backticks
 *                   already stripped (NOTE(review): presumably also lower-cased - confirm)
 * @return true if the CALL references a known Iceberg procedure, false otherwise
 */
private def isIcebergProcedure(normalized: String): Boolean = {
  normalized.startsWith("call") && {
    // Look only at the text before the first '(' so that a string literal argument such as
    // 'system.rollback_to_snapshot' cannot make an unrelated procedure look like an Iceberg
    // one. A statement without any '(' is inspected in full, which keeps malformed calls to
    // Iceberg procedures matching as before.
    val callTarget = normalized.takeWhile(_ != '(')
    SparkProcedures.names().asScala.exists(name => callTarget.contains("system." + name))
  }
}

private def isSnapshotRefDdl(normalized: String): Boolean = {
normalized.contains("create branch") ||
normalized.contains("replace branch") ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,37 @@ public static void stopSpark() {
currentSpark.stop();
}

// A CALL that does not reference an Iceberg procedure must fail with the syntax error
// asserted below (error class PARSE_SYNTAX_ERROR) rather than being parsed as a CallStatement.
@Test
public void testDelegateUnsupportedProcedure() {
  String unsupportedCall = "CALL cat.d.t()";

  assertThatThrownBy(() -> parser.parsePlan(unsupportedCall))
      .isInstanceOfSatisfying(
          ParseException.class,
          syntaxError -> {
            Assert.assertEquals("PARSE_SYNTAX_ERROR", syntaxError.getErrorClass());
            Assert.assertEquals("Syntax error at or near 'CALL'", syntaxError.message());
          });
}

// Back-quoted identifiers in the procedure name must still be parsed as a CallStatement,
// and the quotes must not appear in the parsed name parts.
@Test
public void testCallWithBackticks() throws ParseException {
  String quotedCall = "CALL cat.`system`.`rollback_to_snapshot`()";
  CallStatement parsed = (CallStatement) parser.parsePlan(quotedCall);

  List<String> expectedName = ImmutableList.of("cat", "system", "rollback_to_snapshot");
  Assert.assertEquals(expectedName, JavaConverters.seqAsJavaList(parsed.name()));

  // The call was made with an empty argument list.
  Assert.assertEquals(0, parsed.args().size());
}

@Test
public void testCallWithPositionalArgs() throws ParseException {
CallStatement call =
(CallStatement) parser.parsePlan("CALL c.n.func(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)");
(CallStatement)
parser.parsePlan(
"CALL c.system.rollback_to_snapshot(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)");
Assert.assertEquals(
ImmutableList.of("c", "n", "func"), JavaConverters.seqAsJavaList(call.name()));
ImmutableList.of("c", "system", "rollback_to_snapshot"),
JavaConverters.seqAsJavaList(call.name()));

Assert.assertEquals(7, call.args().size());

Expand All @@ -94,9 +119,12 @@ public void testCallWithPositionalArgs() throws ParseException {
@Test
public void testCallWithNamedArgs() throws ParseException {
CallStatement call =
(CallStatement) parser.parsePlan("CALL cat.system.func(c1 => 1, c2 => '2', c3 => true)");
(CallStatement)
parser.parsePlan(
"CALL cat.system.rollback_to_snapshot(c1 => 1, c2 => '2', c3 => true)");
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
ImmutableList.of("cat", "system", "rollback_to_snapshot"),
JavaConverters.seqAsJavaList(call.name()));

Assert.assertEquals(3, call.args().size());

Expand All @@ -107,9 +135,11 @@ public void testCallWithNamedArgs() throws ParseException {

@Test
public void testCallWithMixedArgs() throws ParseException {
CallStatement call = (CallStatement) parser.parsePlan("CALL cat.system.func(c1 => 1, '2')");
CallStatement call =
(CallStatement) parser.parsePlan("CALL cat.system.rollback_to_snapshot(c1 => 1, '2')");
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
ImmutableList.of("cat", "system", "rollback_to_snapshot"),
JavaConverters.seqAsJavaList(call.name()));

Assert.assertEquals(2, call.args().size());

Expand All @@ -121,9 +151,11 @@ public void testCallWithMixedArgs() throws ParseException {
public void testCallWithTimestampArg() throws ParseException {
CallStatement call =
(CallStatement)
parser.parsePlan("CALL cat.system.func(TIMESTAMP '2017-02-03T10:37:30.00Z')");
parser.parsePlan(
"CALL cat.system.rollback_to_snapshot(TIMESTAMP '2017-02-03T10:37:30.00Z')");
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
ImmutableList.of("cat", "system", "rollback_to_snapshot"),
JavaConverters.seqAsJavaList(call.name()));

Assert.assertEquals(1, call.args().size());

Expand All @@ -134,9 +166,11 @@ public void testCallWithTimestampArg() throws ParseException {
@Test
public void testCallWithVarSubstitution() throws ParseException {
CallStatement call =
(CallStatement) parser.parsePlan("CALL cat.system.func('${spark.extra.prop}')");
(CallStatement)
parser.parsePlan("CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')");
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
ImmutableList.of("cat", "system", "rollback_to_snapshot"),
JavaConverters.seqAsJavaList(call.name()));

Assert.assertEquals(1, call.args().size());

Expand All @@ -145,30 +179,32 @@ public void testCallWithVarSubstitution() throws ParseException {

@Test
public void testCallParseError() {
assertThatThrownBy(() -> parser.parsePlan("CALL cat.system radish kebab"))
assertThatThrownBy(() -> parser.parsePlan("CALL cat.system.rollback_to_snapshot kebab"))
.as("Should fail with a sensible parse error")
.isInstanceOf(IcebergParseException.class)
.hasMessageContaining("missing '(' at 'radish'");
.hasMessageContaining("missing '(' at 'kebab'");
}

@Test
public void testCallStripsComments() throws ParseException {
List<String> callStatementsWithComments =
Lists.newArrayList(
"/* bracketed comment */ CALL cat.system.func('${spark.extra.prop}')",
"/**/ CALL cat.system.func('${spark.extra.prop}')",
"-- single line comment \n CALL cat.system.func('${spark.extra.prop}')",
"-- multiple \n-- single line \n-- comments \n CALL cat.system.func('${spark.extra.prop}')",
"/* select * from multiline_comment \n where x like '%sql%'; */ CALL cat.system.func('${spark.extra.prop}')",
"/* bracketed comment */ CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')",
"/**/ CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')",
"-- single line comment \n CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')",
"-- multiple \n-- single line \n-- comments \n CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')",
"/* select * from multiline_comment \n where x like '%sql%'; */ CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')",
"/* {\"app\": \"dbt\", \"dbt_version\": \"1.0.1\", \"profile_name\": \"profile1\", \"target_name\": \"dev\", "
+ "\"node_id\": \"model.profile1.stg_users\"} \n*/ CALL cat.system.func('${spark.extra.prop}')",
+ "\"node_id\": \"model.profile1.stg_users\"} \n*/ CALL cat.system.rollback_to_snapshot('${spark.extra.prop}')",
"/* Some multi-line comment \n"
+ "*/ CALL /* inline comment */ cat.system.func('${spark.extra.prop}') -- ending comment",
"CALL -- a line ending comment\n" + "cat.system.func('${spark.extra.prop}')");
+ "*/ CALL /* inline comment */ cat.system.rollback_to_snapshot('${spark.extra.prop}') -- ending comment",
"CALL -- a line ending comment\n"
+ "cat.system.rollback_to_snapshot('${spark.extra.prop}')");
for (String sqlText : callStatementsWithComments) {
CallStatement call = (CallStatement) parser.parsePlan(sqlText);
Assert.assertEquals(
ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(call.name()));
ImmutableList.of("cat", "system", "rollback_to_snapshot"),
JavaConverters.seqAsJavaList(call.name()));

Assert.assertEquals(1, call.args().size());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;

public class TestCherrypickSnapshotProcedure extends SparkExtensionsTestBase {
Expand Down Expand Up @@ -178,8 +179,13 @@ public void testInvalidCherrypickSnapshotCases() {

assertThatThrownBy(() -> sql("CALL %s.custom.cherrypick_snapshot('n', 't', 1L)", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.isInstanceOf(ParseException.class)
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.cherrypick_snapshot('t')", catalogName))
.as("Should reject calls without all required args")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
import org.apache.iceberg.spark.source.SimpleRecord;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -178,8 +178,12 @@ public void testInvalidExpireSnapshotsCases() {

assertThatThrownBy(() -> sql("CALL %s.custom.expire_snapshots('n', 't')", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.expire_snapshots()", catalogName))
.as("Should reject calls without all required args")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
import org.apache.iceberg.Table;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;

public class TestFastForwardBranchProcedure extends SparkExtensionsTestBase {
Expand Down Expand Up @@ -176,8 +177,13 @@ public void testInvalidFastForwardBranchCases() {
assertThatThrownBy(
() ->
sql("CALL %s.custom.fast_forward('test_table', 'main', 'newBranch')", catalogName))
.isInstanceOf(NoSuchProcedureException.class)
.hasMessage("Procedure custom.fast_forward not found");
.isInstanceOf(ParseException.class)
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.fast_forward('test_table', 'main')", catalogName))
.isInstanceOf(AnalysisException.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;

public class TestPublishChangesProcedure extends SparkExtensionsTestBase {
Expand Down Expand Up @@ -176,8 +177,12 @@ public void testInvalidApplyWapChangesCases() {
assertThatThrownBy(
() -> sql("CALL %s.custom.publish_changes('n', 't', 'not_valid')", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.publish_changes('t')", catalogName))
.as("Should reject calls without all required args")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
Expand Down Expand Up @@ -266,8 +265,12 @@ public void testInvalidRemoveOrphanFilesCases() {

assertThatThrownBy(() -> sql("CALL %s.custom.remove_orphan_files('n', 't')", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.remove_orphan_files()", catalogName))
.as("Should reject calls without all required args")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Assume;
Expand Down Expand Up @@ -566,8 +566,12 @@ public void testInvalidCasesForRewriteDataFiles() {

assertThatThrownBy(() -> sql("CALL %s.custom.rewrite_data_files('n', 't')", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.rewrite_data_files()", catalogName))
.as("Should reject calls without all required args")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -284,8 +284,12 @@ public void testInvalidRewriteManifestsCases() {

assertThatThrownBy(() -> sql("CALL %s.custom.rewrite_manifests('n', 't')", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.rewrite_manifests()", catalogName))
.as("Should reject calls without all required args")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Test;

Expand Down Expand Up @@ -261,8 +262,12 @@ public void testInvalidRollbackToSnapshotCases() {

assertThatThrownBy(() -> sql("CALL %s.custom.rollback_to_snapshot('n', 't', 1L)", catalogName))
.as("Should not resolve procedures in arbitrary namespaces")
.isInstanceOf(NoSuchProcedureException.class)
.hasMessageContaining("not found");
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
Assert.assertEquals("PARSE_SYNTAX_ERROR", parseException.getErrorClass());
Assert.assertEquals("Syntax error at or near 'CALL'", parseException.message());
});

assertThatThrownBy(() -> sql("CALL %s.system.rollback_to_snapshot('t')", catalogName))
.as("Should reject calls without all required args")
Expand Down
Loading

0 comments on commit 568940f

Please sign in to comment.