From c4f5ea8044c5ff071b852d61c87d1d04fe7e1a70 Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Fri, 28 Jun 2024 09:33:21 -0700 Subject: [PATCH] Restrict UDF functions Signed-off-by: Vamsi Manohar --- .../src/main/antlr/FlintSparkSqlExtensions.g4 | 2 +- .../src/main/antlr/SqlBaseLexer.g4 | 15 +++ .../src/main/antlr/SqlBaseParser.g4 | 108 +++++++++++++++++- .../dispatcher/SparkQueryDispatcher.java | 4 + .../sql/spark/utils/SQLQueryUtils.java | 22 ++++ .../dispatcher/SparkQueryDispatcherTest.java | 30 +++++ 6 files changed, 175 insertions(+), 6 deletions(-) diff --git a/async-query-core/src/main/antlr/FlintSparkSqlExtensions.g4 b/async-query-core/src/main/antlr/FlintSparkSqlExtensions.g4 index dc097d596d..2e8d634dad 100644 --- a/async-query-core/src/main/antlr/FlintSparkSqlExtensions.g4 +++ b/async-query-core/src/main/antlr/FlintSparkSqlExtensions.g4 @@ -188,7 +188,7 @@ indexColTypeList ; indexColType - : identifier skipType=(PARTITION | VALUE_SET | MIN_MAX | BLOOM_FILTER) + : multipartIdentifier skipType=(PARTITION | VALUE_SET | MIN_MAX | BLOOM_FILTER) (LEFT_PAREN skipParams RIGHT_PAREN)? ; diff --git a/async-query-core/src/main/antlr/SqlBaseLexer.g4 b/async-query-core/src/main/antlr/SqlBaseLexer.g4 index a9705c1733..85a4633e80 100644 --- a/async-query-core/src/main/antlr/SqlBaseLexer.g4 +++ b/async-query-core/src/main/antlr/SqlBaseLexer.g4 @@ -134,6 +134,7 @@ AS: 'AS'; ASC: 'ASC'; AT: 'AT'; AUTHORIZATION: 'AUTHORIZATION'; +BEGIN: 'BEGIN'; BETWEEN: 'BETWEEN'; BIGINT: 'BIGINT'; BINARY: 'BINARY'; @@ -145,6 +146,7 @@ BUCKETS: 'BUCKETS'; BY: 'BY'; BYTE: 'BYTE'; CACHE: 'CACHE'; +CALLED: 'CALLED'; CASCADE: 'CASCADE'; CASE: 'CASE'; CAST: 'CAST'; @@ -171,6 +173,7 @@ COMPENSATION: 'COMPENSATION'; COMPUTE: 'COMPUTE'; CONCATENATE: 'CONCATENATE'; CONSTRAINT: 'CONSTRAINT'; +CONTAINS: 'CONTAINS'; COST: 'COST'; CREATE: 'CREATE'; CROSS: 'CROSS'; @@ -197,10 +200,12 @@ DECIMAL: 'DECIMAL'; DECLARE: 'DECLARE'; DEFAULT: 'DEFAULT'; DEFINED: 'DEFINED'; +DEFINER: 'DEFINER'; DELETE: 'DELETE'; DELIMITED: 'DELIMITED'; DESC: 'DESC'; DESCRIBE: 'DESCRIBE'; +DETERMINISTIC: 'DETERMINISTIC'; DFS: 'DFS'; DIRECTORIES: 'DIRECTORIES'; DIRECTORY: 'DIRECTORY'; @@ -259,6 +264,7 @@ INDEX: 'INDEX'; INDEXES: 'INDEXES'; INNER: 'INNER'; INPATH: 'INPATH'; +INPUT: 'INPUT'; INPUTFORMAT: 'INPUTFORMAT'; INSERT: 'INSERT'; INTERSECT: 'INTERSECT'; @@ -266,10 +272,12 @@ INTERVAL: 'INTERVAL'; INT: 'INT'; INTEGER: 'INTEGER'; INTO: 'INTO'; +INVOKER: 'INVOKER'; IS: 'IS'; ITEMS: 'ITEMS'; JOIN: 'JOIN'; KEYS: 'KEYS'; +LANGUAGE: 'LANGUAGE'; LAST: 'LAST'; LATERAL: 'LATERAL'; LAZY: 'LAZY'; @@ -297,6 +305,7 @@ MILLISECOND: 'MILLISECOND'; MILLISECONDS: 'MILLISECONDS'; MINUTE: 'MINUTE'; MINUTES: 'MINUTES'; +MODIFIES: 'MODIFIES'; MONTH: 'MONTH'; MONTHS: 'MONTHS'; MSCK: 'MSCK'; @@ -341,6 +350,7 @@ PURGE: 'PURGE'; QUARTER: 'QUARTER'; QUERY: 'QUERY'; RANGE: 'RANGE'; +READS: 'READS'; REAL: 'REAL'; RECORDREADER: 'RECORDREADER'; RECORDWRITER: 'RECORDWRITER'; @@ -355,6 +365,8 @@ REPLACE: 'REPLACE'; RESET: 'RESET'; RESPECT: 'RESPECT'; RESTRICT: 'RESTRICT'; +RETURN: 'RETURN'; +RETURNS: 'RETURNS'; REVOKE: 'REVOKE'; RIGHT: 'RIGHT'; RLIKE: 'RLIKE' | 'REGEXP'; @@ -368,6 +380,7 @@ SECOND: 'SECOND'; SECONDS: 'SECONDS'; SCHEMA: 'SCHEMA'; SCHEMAS: 'SCHEMAS'; +SECURITY: 'SECURITY'; SELECT: 'SELECT'; SEMI: 'SEMI'; SEPARATED: 'SEPARATED'; @@ -386,6 +399,8 @@ SOME: 'SOME'; SORT: 'SORT'; SORTED: 'SORTED'; SOURCE: 'SOURCE'; +SPECIFIC: 'SPECIFIC'; +SQL: 'SQL'; START: 'START'; STATISTICS: 'STATISTICS'; STORED: 'STORED'; diff --git a/async-query-core/src/main/antlr/SqlBaseParser.g4 b/async-query-core/src/main/antlr/SqlBaseParser.g4 index 4552c17e0c..54eff14b6d 100644 --- a/async-query-core/src/main/antlr/SqlBaseParser.g4 +++ b/async-query-core/src/main/antlr/SqlBaseParser.g4 @@ -42,6 +42,28 @@ options { tokenVocab = SqlBaseLexer; } public boolean double_quoted_identifiers = false; } +compoundOrSingleStatement + : singleStatement + | singleCompoundStatement + ; + +singleCompoundStatement + : beginEndCompoundBlock SEMICOLON? EOF + ; + +beginEndCompoundBlock + : BEGIN compoundBody END + ; + +compoundBody + : (compoundStatements+=compoundStatement SEMICOLON)* + ; + +compoundStatement + : statement + | beginEndCompoundBlock + ; + singleStatement : statement SEMICOLON* EOF ; @@ -83,13 +105,15 @@ statement (WITH (DBPROPERTIES | PROPERTIES) propertyList))* #createNamespace | ALTER namespace identifierReference SET (DBPROPERTIES | PROPERTIES) propertyList #setNamespaceProperties + | ALTER namespace identifierReference + UNSET (DBPROPERTIES | PROPERTIES) propertyList #unsetNamespaceProperties | ALTER namespace identifierReference SET locationSpec #setNamespaceLocation | DROP namespace (IF EXISTS)? identifierReference (RESTRICT | CASCADE)? #dropNamespace | SHOW namespaces ((FROM | IN) multipartIdentifier)? (LIKE? pattern=stringLit)? #showNamespaces - | createTableHeader (LEFT_PAREN createOrReplaceTableColTypeList RIGHT_PAREN)? tableProvider? + | createTableHeader (LEFT_PAREN colDefinitionList RIGHT_PAREN)? tableProvider? createTableClauses (AS? query)? #createTable | CREATE TABLE (IF errorCapturingNot EXISTS)? target=tableIdentifier @@ -99,7 +123,7 @@ statement createFileFormat | locationSpec | (TBLPROPERTIES tableProps=propertyList))* #createTableLike - | replaceTableHeader (LEFT_PAREN createOrReplaceTableColTypeList RIGHT_PAREN)? tableProvider? + | replaceTableHeader (LEFT_PAREN colDefinitionList RIGHT_PAREN)? tableProvider? createTableClauses (AS? query)? #replaceTable | ANALYZE TABLE identifierReference partitionSpec? COMPUTE STATISTICS @@ -168,6 +192,11 @@ statement | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF errorCapturingNot EXISTS)? identifierReference AS className=stringLit (USING resource (COMMA resource)*)? #createFunction + | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF errorCapturingNot EXISTS)? + identifierReference LEFT_PAREN parameters=colDefinitionList? RIGHT_PAREN + (RETURNS (dataType | TABLE LEFT_PAREN returnParams=colTypeList RIGHT_PAREN))? + routineCharacteristics + RETURN (query | expression) #createUserDefinedFunction | DROP TEMPORARY? FUNCTION (IF EXISTS)? identifierReference #dropFunction | DECLARE (OR REPLACE)? VARIABLE? identifierReference dataType? variableDefaultExpression? #createVariable @@ -1186,11 +1215,11 @@ colType : colName=errorCapturingIdentifier dataType (errorCapturingNot NULL)? commentSpec? ; -createOrReplaceTableColTypeList - : createOrReplaceTableColType (COMMA createOrReplaceTableColType)* +colDefinitionList + : colDefinition (COMMA colDefinition)* ; -createOrReplaceTableColType +colDefinition : colName=errorCapturingIdentifier dataType colDefinitionOption* ; @@ -1213,6 +1242,46 @@ complexColType : errorCapturingIdentifier COLON? dataType (errorCapturingNot NULL)? commentSpec? ; +routineCharacteristics + : (routineLanguage + | specificName + | deterministic + | sqlDataAccess + | nullCall + | commentSpec + | rightsClause)* + ; + +routineLanguage + : LANGUAGE (SQL | IDENTIFIER) + ; + +specificName + : SPECIFIC specific=errorCapturingIdentifier + ; + +deterministic + : DETERMINISTIC + | errorCapturingNot DETERMINISTIC + ; + +sqlDataAccess + : access=NO SQL + | access=CONTAINS SQL + | access=READS SQL DATA + | access=MODIFIES SQL DATA + ; + +nullCall + : RETURNS NULL ON NULL INPUT + | CALLED ON NULL INPUT + ; + +rightsClause + : SQL SECURITY INVOKER + | SQL SECURITY DEFINER + ; + whenClause : WHEN condition=expression THEN result=expression ; @@ -1360,6 +1429,7 @@ ansiNonReserved | ARRAY | ASC | AT + | BEGIN | BETWEEN | BIGINT | BINARY @@ -1371,6 +1441,7 @@ ansiNonReserved | BY | BYTE | CACHE + | CALLED | CASCADE | CATALOG | CATALOGS @@ -1390,6 +1461,7 @@ ansiNonReserved | COMPENSATION | COMPUTE | CONCATENATE + | CONTAINS | COST | CUBE | CURRENT @@ -1410,10 +1482,12 @@ ansiNonReserved | DECLARE | DEFAULT | DEFINED + | DEFINER | DELETE | DELIMITED | DESC | DESCRIBE + | DETERMINISTIC | DFS | DIRECTORIES | DIRECTORY @@ -1454,13 +1528,16 @@ ansiNonReserved | INDEX | INDEXES | INPATH + | INPUT | INPUTFORMAT | INSERT | INT | INTEGER | INTERVAL + | INVOKER | ITEMS | KEYS + | LANGUAGE | LAST | LAZY | LIKE @@ -1485,6 +1562,7 @@ ansiNonReserved | MILLISECONDS | MINUTE | MINUTES + | MODIFIES | MONTH | MONTHS | MSCK @@ -1518,6 +1596,7 @@ ansiNonReserved | QUARTER | QUERY | RANGE + | READS | REAL | RECORDREADER | RECORDWRITER @@ -1531,6 +1610,8 @@ ansiNonReserved | RESET | RESPECT | RESTRICT + | RETURN + | RETURNS | REVOKE | RLIKE | ROLE @@ -1543,6 +1624,7 @@ ansiNonReserved | SCHEMAS | SECOND | SECONDS + | SECURITY | SEMI | SEPARATED | SERDE @@ -1558,6 +1640,7 @@ ansiNonReserved | SORT | SORTED | SOURCE + | SPECIFIC | START | STATISTICS | STORED @@ -1662,6 +1745,7 @@ nonReserved | ASC | AT | AUTHORIZATION + | BEGIN | BETWEEN | BIGINT | BINARY @@ -1674,6 +1758,7 @@ nonReserved | BY | BYTE | CACHE + | CALLED | CASCADE | CASE | CAST @@ -1700,6 +1785,7 @@ nonReserved | COMPUTE | CONCATENATE | CONSTRAINT + | CONTAINS | COST | CREATE | CUBE @@ -1725,10 +1811,12 @@ nonReserved | DECLARE | DEFAULT | DEFINED + | DEFINER | DELETE | DELIMITED | DESC | DESCRIBE + | DETERMINISTIC | DFS | DIRECTORIES | DIRECTORY @@ -1784,15 +1872,18 @@ nonReserved | INDEX | INDEXES | INPATH + | INPUT | INPUTFORMAT | INSERT | INT | INTEGER | INTERVAL | INTO + | INVOKER | IS | ITEMS | KEYS + | LANGUAGE | LAST | LAZY | LEADING @@ -1819,6 +1910,7 @@ nonReserved | MILLISECONDS | MINUTE | MINUTES + | MODIFIES | MONTH | MONTHS | MSCK @@ -1861,6 +1953,7 @@ nonReserved | QUARTER | QUERY | RANGE + | READS | REAL | RECORDREADER | RECORDWRITER @@ -1875,6 +1968,8 @@ nonReserved | RESET | RESPECT | RESTRICT + | RETURN + | RETURNS | REVOKE | RLIKE | ROLE @@ -1887,6 +1982,7 @@ nonReserved | SCHEMAS | SECOND | SECONDS + | SECURITY | SELECT | SEPARATED | SERDE @@ -1903,6 +1999,8 @@ nonReserved | SORT | SORTED | SOURCE + | SPECIFIC + | SQL | START | STATISTICS | STORED diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 5facdee567..6faccf7e73 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -57,6 +57,10 @@ public DispatchQueryResponse dispatch( return getQueryHandlerForFlintExtensionQuery(indexQueryDetails) .submit(dispatchQueryRequest, context); } else { + if (LangType.SQL.equals(dispatchQueryRequest.getLangType()) + && !SQLQueryUtils.isSparkSqlQueryAllowed(dispatchQueryRequest.getQuery())) { + throw new IllegalArgumentException("Query is not allowed as it contains function creation"); + } DispatchQueryContext context = getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) .asyncQueryRequestContext(asyncQueryRequestContext) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java index 9dfe30b4b5..97b74044aa 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -71,6 +71,28 @@ public static boolean isFlintExtensionQuery(String sqlQuery) { } } + public static boolean isSparkSqlQueryAllowed(String sqlQuery) { + SparkSqlValidatorVisitor sparkSqlValidatorVisitor = new SparkSqlValidatorVisitor(); + SqlBaseParser sqlBaseParser = + new SqlBaseParser( + new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery)))); + sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener()); + SqlBaseParser.StatementContext statement = sqlBaseParser.statement(); + sparkSqlValidatorVisitor.visit(statement); + return sparkSqlValidatorVisitor.getIsQueryAllowed(); + } + + public static class SparkSqlValidatorVisitor extends SqlBaseParserBaseVisitor { + + @Getter private Boolean isQueryAllowed = Boolean.TRUE; + + @Override + public Boolean visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) { + isQueryAllowed = Boolean.FALSE; + return super.visitCreateFunction(ctx); + } + } + public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor { @Getter private FullyQualifiedTableName fullyQualifiedTableName; diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index d57284b9ca..9d713f147e 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -40,9 +40,11 @@ import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRun; import com.amazonaws.services.emrserverless.model.JobRunState; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Optional; import org.json.JSONObject; @@ -497,6 +499,34 @@ void testDispatchWithPPLQuery() { verifyNoInteractions(flintIndexMetadataService); } + @Test + void testDispatchWithSparkUDFQuery() { + List udfQueries = new ArrayList<>(); + udfQueries.add( + "CREATE FUNCTION celsius_to_fahrenheit AS 'org.apache.spark.sql.functions.expr(\"(celsius *" + + " 9/5) + 32\")'"); + udfQueries.add( + "CREATE TEMPORARY FUNCTION square AS 'org.apache.spark.sql.functions.expr(\"num * num\")'"); + for (String query : udfQueries) { + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + .thenReturn(dataSourceMetadata); + + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.dispatch( + getBaseDispatchQueryRequestBuilder(query).langType(LangType.SQL).build(), + asyncQueryRequestContext)); + Assertions.assertEquals( + "Query is not allowed as it contains function creation", + illegalArgumentException.getMessage()); + verifyNoInteractions(emrServerlessClient); + verifyNoInteractions(flintIndexMetadataService); + } + } + @Test void testDispatchQueryWithoutATableAndDataSourceName() { when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);