Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restrict UDF functions #2790

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions async-query-core/src/main/antlr/SqlBaseLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ NANOSECOND: 'NANOSECOND';
NANOSECONDS: 'NANOSECONDS';
NATURAL: 'NATURAL';
NO: 'NO';
NONE: 'NONE';
NOT: 'NOT';
NULL: 'NULL';
NULLS: 'NULLS';
Expand Down
52 changes: 39 additions & 13 deletions async-query-core/src/main/antlr/SqlBaseParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ singleCompoundStatement
;

beginEndCompoundBlock
: BEGIN compoundBody END
: beginLabel? BEGIN compoundBody END endLabel?
;

compoundBody
Expand All @@ -61,11 +61,26 @@ compoundBody

compoundStatement
: statement
| setStatementWithOptionalVarKeyword
| beginEndCompoundBlock
;

setStatementWithOptionalVarKeyword
: SET (VARIABLE | VAR)? assignmentList #setVariableWithOptionalKeyword
| SET (VARIABLE | VAR)? LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ
LEFT_PAREN query RIGHT_PAREN #setVariableWithOptionalKeyword
;

singleStatement
: statement SEMICOLON* EOF
: (statement|setResetStatement) SEMICOLON* EOF
;

beginLabel
: multipartIdentifier COLON
;

endLabel
: multipartIdentifier
;

singleExpression
Expand Down Expand Up @@ -174,6 +189,8 @@ statement
| ALTER TABLE identifierReference
(partitionSpec)? SET locationSpec #setTableLocation
| ALTER TABLE identifierReference RECOVER PARTITIONS #recoverPartitions
| ALTER TABLE identifierReference
(clusterBySpec | CLUSTER BY NONE) #alterClusterBy
| DROP TABLE (IF EXISTS)? identifierReference PURGE? #dropTable
| DROP VIEW (IF EXISTS)? identifierReference #dropView
| CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)?
Expand Down Expand Up @@ -202,7 +219,7 @@ statement
identifierReference dataType? variableDefaultExpression? #createVariable
| DROP TEMPORARY VARIABLE (IF EXISTS)? identifierReference #dropVariable
| EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)?
statement #explain
(statement|setResetStatement) #explain
| SHOW TABLES ((FROM | IN) identifierReference)?
(LIKE? pattern=stringLit)? #showTables
| SHOW TABLE EXTENDED ((FROM | IN) ns=identifierReference)?
Expand Down Expand Up @@ -241,26 +258,29 @@ statement
| (MSCK)? REPAIR TABLE identifierReference
(option=(ADD|DROP|SYNC) PARTITIONS)? #repairTable
| op=(ADD | LIST) identifier .*? #manageResource
| SET COLLATION collationName=identifier #setCollation
| SET ROLE .*? #failNativeCommand
| CREATE INDEX (IF errorCapturingNot EXISTS)? identifier ON TABLE?
identifierReference (USING indexType=identifier)?
LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
(OPTIONS options=propertyList)? #createIndex
| DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex
| unsupportedHiveNativeCommands .*? #failNativeCommand
;

setResetStatement
: SET COLLATION collationName=identifier #setCollation
| SET ROLE .*? #failSetRole
| SET TIME ZONE interval #setTimeZone
| SET TIME ZONE timezone #setTimeZone
| SET TIME ZONE .*? #setTimeZone
| SET (VARIABLE | VAR) assignmentList #setVariable
| SET (VARIABLE | VAR) LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ
LEFT_PAREN query RIGHT_PAREN #setVariable
LEFT_PAREN query RIGHT_PAREN #setVariable
| SET configKey EQ configValue #setQuotedConfiguration
| SET configKey (EQ .*?)? #setConfiguration
| SET .*? EQ configValue #setQuotedConfiguration
| SET .*? #setConfiguration
| RESET configKey #resetQuotedConfiguration
| RESET .*? #resetConfiguration
| CREATE INDEX (IF errorCapturingNot EXISTS)? identifier ON TABLE?
identifierReference (USING indexType=identifier)?
LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
(OPTIONS options=propertyList)? #createIndex
| DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex
| unsupportedHiveNativeCommands .*? #failNativeCommand
;

executeImmediate
Expand Down Expand Up @@ -853,13 +873,17 @@ identifierComment

relationPrimary
: identifierReference temporalClause?
sample? tableAlias #tableName
optionsClause? sample? tableAlias #tableName
| LEFT_PAREN query RIGHT_PAREN sample? tableAlias #aliasedQuery
| LEFT_PAREN relation RIGHT_PAREN sample? tableAlias #aliasedRelation
| inlineTable #inlineTableDefault2
| functionTable #tableValuedFunction
;

optionsClause
: WITH options=propertyList
;

inlineTable
: VALUES expression (COMMA expression)* tableAlias
;
Expand Down Expand Up @@ -1572,6 +1596,7 @@ ansiNonReserved
| NANOSECOND
| NANOSECONDS
| NO
| NONE
| NULLS
| NUMERIC
| OF
Expand Down Expand Up @@ -1920,6 +1945,7 @@ nonReserved
| NANOSECOND
| NANOSECONDS
| NO
| NONE
| NOT
| NULL
| NULLS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.sql.spark.dispatcher;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import org.jetbrains.annotations.NotNull;
Expand Down Expand Up @@ -45,25 +46,50 @@ public DispatchQueryResponse dispatch(
this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
dispatchQueryRequest.getDatasource());

if (LangType.SQL.equals(dispatchQueryRequest.getLangType())
&& SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) {
IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.indexQueryDetails(indexQueryDetails)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails)
.submit(dispatchQueryRequest, context);
} else {
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId())
.submit(dispatchQueryRequest, context);
if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) {
String query = dispatchQueryRequest.getQuery();

if (SQLQueryUtils.isFlintExtensionQuery(query)) {
return handleFlintExtensionQuery(
dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
}

List<String> validationErrors = SQLQueryUtils.validateSparkSqlQuery(query);
if (!validationErrors.isEmpty()) {
throw new IllegalArgumentException(
"Query is not allowed: " + String.join(", ", validationErrors));
}
}
return handleDefaultQuery(dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
}

private DispatchQueryResponse handleFlintExtensionQuery(
DispatchQueryRequest dispatchQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext,
DataSourceMetadata dataSourceMetadata) {
IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.indexQueryDetails(indexQueryDetails)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails)
.submit(dispatchQueryRequest, context);
}

private DispatchQueryResponse handleDefaultQuery(
DispatchQueryRequest dispatchQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext,
DataSourceMetadata dataSourceMetadata) {

DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId())
.submit(dispatchQueryRequest, context);
}

private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.spark.utils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -73,6 +75,32 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {
}
}

public static List<String> validateSparkSqlQuery(String sqlQuery) {
vmmusings marked this conversation as resolved.
Show resolved Hide resolved
SparkSqlValidatorVisitor sparkSqlValidatorVisitor = new SparkSqlValidatorVisitor();
SqlBaseParser sqlBaseParser =
new SqlBaseParser(
new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
try {
SqlBaseParser.StatementContext statement = sqlBaseParser.statement();
sparkSqlValidatorVisitor.visit(statement);
return sparkSqlValidatorVisitor.getValidationErrors();
} catch (SyntaxCheckException syntaxCheckException) {
return Collections.emptyList();
}
}

private static class SparkSqlValidatorVisitor extends SqlBaseParserBaseVisitor<Void> {

@Getter private final List<String> validationErrors = new ArrayList<>();

@Override
public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) {
validationErrors.add("Creating user-defined functions is not allowed");
return super.visitCreateFunction(ctx);
}
}

public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {

@Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@
import com.amazonaws.services.emrserverless.model.GetJobRunResult;
import com.amazonaws.services.emrserverless.model.JobRun;
import com.amazonaws.services.emrserverless.model.JobRunState;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -438,6 +440,77 @@ void testDispatchWithPPLQuery() {
verifyNoInteractions(flintIndexMetadataService);
}

@Test
void testDispatchWithSparkUDFQuery() {
List<String> udfQueries = new ArrayList<>();
udfQueries.add(
"CREATE FUNCTION celsius_to_fahrenheit AS 'org.apache.spark.sql.functions.expr(\"(celsius *"
+ " 9/5) + 32\")'");
udfQueries.add(
"CREATE TEMPORARY FUNCTION square AS 'org.apache.spark.sql.functions.expr(\"num * num\")'");
for (String query : udfQueries) {
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
.thenReturn(dataSourceMetadata);

IllegalArgumentException illegalArgumentException =
Assertions.assertThrows(
IllegalArgumentException.class,
() ->
sparkQueryDispatcher.dispatch(
getBaseDispatchQueryRequestBuilder(query).langType(LangType.SQL).build(),
asyncQueryRequestContext));
Assertions.assertEquals(
"Query is not allowed: Creating user-defined functions is not allowed",
illegalArgumentException.getMessage());
verifyNoInteractions(emrServerlessClient);
verifyNoInteractions(flintIndexMetadataService);
}
}

@Test
void testInvalidSQLQueryDispatchToSpark() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
String query = "myselect 1";
String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
StartJobRequest expected =
new StartJobRequest(
"TEST_CLUSTER:batch",
null,
EMRS_APPLICATION_ID,
EMRS_EXECUTION_ROLE,
sparkSubmitParameters,
tags,
false,
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
sparkQueryDispatcher.dispatch(
DispatchQueryRequest.builder()
.applicationId(EMRS_APPLICATION_ID)
.query(query)
.datasource(MY_GLUE)
.langType(LangType.SQL)
.executionRoleARN(EMRS_EXECUTION_ROLE)
.clusterName(TEST_CLUSTER_NAME)
.sparkSubmitParameterModifier(sparkSubmitParameterModifier)
.build(),
asyncQueryRequestContext);

verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
verifyNoInteractions(flintIndexMetadataService);
}

@Test
void testDispatchQueryWithoutATableAndDataSourceName() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,25 @@ void testAutoRefresh() {
.autoRefresh());
}

@Test
void testValidateSparkSqlQuery_ValidQuery() {
String validQuery = "SELECT * FROM users WHERE age > 18";
List<String> errors = SQLQueryUtils.validateSparkSqlQuery(validQuery);
assertTrue(errors.isEmpty(), "Valid query should not produce any errors");
}

@Test
void testValidateSparkSqlQuery_InvalidQuery() {
String invalidQuery = "CREATE FUNCTION myUDF AS 'com.example.UDF'";
List<String> errors = SQLQueryUtils.validateSparkSqlQuery(invalidQuery);
assertFalse(errors.isEmpty(), "Invalid query should produce errors");
assertEquals(1, errors.size(), "Should have one error");
assertEquals(
"Creating user-defined functions is not allowed",
errors.get(0),
"Error message should match");
}

@Getter
protected static class IndexQuery {
private String query;
Expand Down
2 changes: 2 additions & 0 deletions docs/user/interfaces/asyncqueryinterface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ Async Query Creation API
======================================
If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/async_query/create``.

Limitation: Spark SQL queries that create User-Defined Functions (UDFs) are not allowed.

HTTP URI: ``_plugins/_async_query``

HTTP VERB: ``POST``
Expand Down
Loading