Skip to content

Commit

Permalink
fix: allow subquery in ad-hoc SQL (WIP) (apache#19242)
Browse files Browse the repository at this point in the history
* allow adhoc subquery

* add config for allow ad hoc subquery

* default to true allow adhoc subquery

* fix test

* Update superset/errors.py

Co-authored-by: Beto Dealmeida <roberto@dealmeida.net>

* Update superset/connectors/sqla/utils.py

Co-authored-by: David Aaron Suddjian <1858430+suddjian@users.noreply.github.com>

* rename and add doc string

* fix for big query test

* Update superset/connectors/sqla/utils.py

Co-authored-by: Beto Dealmeida <roberto@dealmeida.net>

* Apply suggestions from code review

Co-authored-by: Beto Dealmeida <roberto@dealmeida.net>

* add test

* update validate adhoc subquery

Co-authored-by: Beto Dealmeida <roberto@dealmeida.net>
Co-authored-by: David Aaron Suddjian <1858430+suddjian@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 18, 2022
1 parent 48a12ad commit 50902d5
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 2 deletions.
1 change: 1 addition & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
"ALLOW_FULL_CSV_EXPORT": False,
"UX_BETA": False,
"GENERIC_CHART_AXES": False,
"ALLOW_ADHOC_SUBQUERY": False,
}

# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
Expand Down
7 changes: 7 additions & 0 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from superset.connectors.sqla.utils import (
get_physical_table_metadata,
get_virtual_table_metadata,
validate_adhoc_subquery,
)
from superset.datasets.models import Dataset as NewDataset
from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression
Expand Down Expand Up @@ -885,6 +886,7 @@ def adhoc_metric_to_sqla(
elif expression_type == utils.AdhocMetricExpressionType.SQL:
tp = self.get_template_processor()
expression = tp.process_template(cast(str, metric["sqlExpression"]))
validate_adhoc_subquery(expression)
sqla_metric = literal_column(expression)
else:
raise QueryObjectValidationError("Adhoc metric expressionType is invalid")
Expand All @@ -908,6 +910,8 @@ def adhoc_column_to_sqla(
expression = col["sqlExpression"]
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
validate_adhoc_subquery(expression)
sqla_metric = literal_column(expression)
return self.make_sqla_column_compatible(sqla_metric, label)

Expand Down Expand Up @@ -1166,6 +1170,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
elif selected in columns_by_name:
outer = columns_by_name[selected].get_sqla_col()
else:
validate_adhoc_subquery(selected)
outer = literal_column(f"({selected})")
outer = self.make_sqla_column_compatible(outer, selected)
else:
Expand All @@ -1178,6 +1183,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
select_exprs.append(outer)
elif columns:
for selected in columns:
validate_adhoc_subquery(selected)
select_exprs.append(
columns_by_name[selected].get_sqla_col()
if selected in columns_by_name
Expand Down Expand Up @@ -1389,6 +1395,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
and db_engine_spec.allows_hidden_cc_in_orderby
and col.name in [select_col.name for select_col in select_exprs]
):
validate_adhoc_subquery(str(col.expression))
col = literal_column(col.name)
direction = asc if ascending else desc
qry = qry.order_by(direction(col))
Expand Down
28 changes: 27 additions & 1 deletion superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from contextlib import closing
from typing import Dict, List, Optional, TYPE_CHECKING

import sqlparse
from flask_babel import lazy_gettext as _
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.sql.type_api import TypeEngine
Expand All @@ -28,7 +29,7 @@
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery
from superset.sql_parse import has_table_query, ParsedQuery

if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
Expand Down Expand Up @@ -119,3 +120,28 @@ def get_virtual_table_metadata(dataset: "SqlaTable") -> List[Dict[str, str]]:
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex
return cols


def validate_adhoc_subquery(raw_sql: str) -> None:
"""
Check if adhoc SQL contains sub-queries or nested sub-queries with table
:param raw_sql: adhoc sql expression
:raise SupersetSecurityException if sql contains sub-queries or
nested sub-queries with table
"""
# pylint: disable=import-outside-toplevel
from superset import is_feature_enabled

if is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
return

for statement in sqlparse.parse(raw_sql):
if has_table_query(statement):
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
message=_("Custom SQL fields cannot contain sub-queries."),
level=ErrorLevel.ERROR,
)
)
return
3 changes: 3 additions & 0 deletions superset/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class SupersetErrorType(str, Enum):
SQLLAB_TIMEOUT_ERROR = "SQLLAB_TIMEOUT_ERROR"
RESULTS_BACKEND_ERROR = "RESULTS_BACKEND_ERROR"
ASYNC_WORKERS_ERROR = "ASYNC_WORKERS_ERROR"
ADHOC_SUBQUERY_NOT_ALLOWED_ERROR = "ADHOC_SUBQUERY_NOT_ALLOWED_ERROR"

# Generic errors
GENERIC_COMMAND_ERROR = "GENERIC_COMMAND_ERROR"
Expand Down Expand Up @@ -138,10 +139,12 @@ class SupersetErrorType(str, Enum):
1034: _("The port number is invalid."),
1035: _("Failed to start remote query on a worker."),
1036: _("The database was deleted."),
1037: _("Custom SQL fields cannot contain sub-queries."),
}


ERROR_TYPES_TO_ISSUE_CODES_MAPPING = {
SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR: [1037],
SupersetErrorType.BACKEND_TIMEOUT_ERROR: [1000, 1001],
SupersetErrorType.GENERIC_DB_ENGINE_ERROR: [1002],
SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR: [1003, 1004],
Expand Down
31 changes: 30 additions & 1 deletion tests/integration_tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from superset.constants import EMPTY_STRING, NULL_STRING
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.db_engine_specs.druid import DruidEngineSpec
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import QueryObjectValidationError, SupersetSecurityException
from superset.models.core import Database
from superset.utils.core import (
AdhocMetricExpressionType,
Expand Down Expand Up @@ -239,6 +239,35 @@ def test_jinja_metrics_and_calc_columns(self, flask_g):
db.session.delete(table)
db.session.commit()

def test_adhoc_metrics_and_calc_columns(self):
base_query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["user", "expr"],
"metrics": [
{
"expressionType": AdhocMetricExpressionType.SQL,
"sqlExpression": "(SELECT (SELECT * from birth_names) "
"from test_validate_adhoc_sql)",
"label": "adhoc_metrics",
}
],
"is_timeseries": False,
"filter": [],
}

table = SqlaTable(
table_name="test_validate_adhoc_sql", database=get_example_database()
)
db.session.commit()

with pytest.raises(SupersetSecurityException):
table.get_sqla_query(**base_query_obj)
# Cleanup
db.session.delete(table)
db.session.commit()

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_where_operators(self):
filters: Tuple[FilterTestCase, ...] = (
Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,8 @@ def test_sqlparse_issue_652():
("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
("SELECT * FROM other_table", True),
("extract(HOUR from from_unixtime(hour_ts)", False),
("(SELECT * FROM table)", True),
("(SELECT COUNT(DISTINCT name) from birth_names)", True),
],
)
def test_has_table_query(sql: str, expected: bool) -> None:
Expand Down

0 comments on commit 50902d5

Please sign in to comment.