Skip to content

Commit

Permalink
fix: allow db driver distinction on enforced URI params (#23769)
Browse files Browse the repository at this point in the history
(cherry picked from commit 6ae5388)
  • Loading branch information
dpgaspar authored and eschutho committed May 25, 2023
1 parent a564bd2 commit a63fefc
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 10 deletions.
18 changes: 11 additions & 7 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# This set will give the keywords for data limit statements
# to consider for the engines with TOP SQL parsing
top_keywords: Set[str] = {"TOP"}
# A set of disallowed connection query parameters
disallow_uri_query_params: Set[str] = set()
# A set of disallowed connection query parameters by driver name
disallow_uri_query_params: Dict[str, Set[str]] = {}
# A Dict of query parameters that will always be used on every connection
enforce_uri_query_params: Dict[str, Any] = {}
# by driver name
enforce_uri_query_params: Dict[str, Dict[str, Any]] = {}

force_column_alias_quotes = False
arraysize = 0
Expand Down Expand Up @@ -1024,7 +1025,10 @@ def adjust_database_uri( # pylint: disable=unused-argument
This is important because DB engine specs can be installed from 3rd party
packages.
"""
return uri, {**connect_args, **cls.enforce_uri_query_params}
return uri, {
**connect_args,
**cls.enforce_uri_query_params.get(uri.get_driver_name(), {}),
}

@classmethod
def patch(cls) -> None:
Expand Down Expand Up @@ -1744,9 +1748,9 @@ def validate_database_uri(cls, sqlalchemy_uri: URL) -> None:
:param sqlalchemy_uri:
"""
if existing_disallowed := cls.disallow_uri_query_params.intersection(
sqlalchemy_uri.query
):
if existing_disallowed := cls.disallow_uri_query_params.get(
sqlalchemy_uri.get_driver_name(), set()
).intersection(sqlalchemy_uri.query):
raise ValueError(f"Forbidden query parameter(s): {existing_disallowed}")


Expand Down
10 changes: 8 additions & 2 deletions superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,14 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
{},
),
}
disallow_uri_query_params = {"local_infile"}
enforce_uri_query_params = {"local_infile": 0}
disallow_uri_query_params = {
"mysqldb": {"local_infile"},
"mysqlconnector": {"allow_local_infile"},
}
enforce_uri_query_params = {
"mysqldb": {"local_infile": 0},
"mysqlconnector": {"allow_local_infile": 0},
}

@classmethod
def convert_dttm(
Expand Down
12 changes: 11 additions & 1 deletion tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_impersonate_user_presto(self, mocked_create_engine):
@mock.patch("superset.models.core.create_engine")
def test_adjust_engine_params_mysql(self, mocked_create_engine):
model = Database(
database_name="test_database",
database_name="test_database1",
sqlalchemy_uri="mysql://user:password@localhost",
)
model._get_sqla_engine()
Expand All @@ -203,6 +203,16 @@ def test_adjust_engine_params_mysql(self, mocked_create_engine):
assert str(call_args[0][0]) == "mysql://user:password@localhost"
assert call_args[1]["connect_args"]["local_infile"] == 0

model = Database(
database_name="test_database2",
sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost",
)
model._get_sqla_engine()
call_args = mocked_create_engine.call_args

assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost"
assert call_args[1]["connect_args"]["allow_local_infile"] == 0

@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
principal_user = security_manager.find_user(username="gamma")
Expand Down
28 changes: 28 additions & 0 deletions tests/unit_tests/db_engine_specs/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,11 @@ def test_convert_dttm(
"sqlalchemy_uri,error",
[
("mysql://user:password@host/db1?local_infile=1", True),
("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=1", True),
("mysql://user:password@host/db1?local_infile=0", True),
("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=0", True),
("mysql://user:password@host/db1", False),
("mysql+mysqlconnector://user:password@host/db1", False),
],
)
def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
Expand All @@ -123,18 +126,43 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
"sqlalchemy_uri,connect_args,returns",
[
("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
(
"mysql+mysqlconnector://user:password@host/db1",
{"allow_local_infile": 1},
{"allow_local_infile": 0},
),
("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
(
"mysql+mysqlconnector://user:password@host/db1",
{"allow_local_infile": -1},
{"allow_local_infile": 0},
),
("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
(
"mysql+mysqlconnector://user:password@host/db1",
{"allow_local_infile": 0},
{"allow_local_infile": 0},
),
(
"mysql://user:password@host/db1",
{"param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
(
"mysql+mysqlconnector://user:password@host/db1",
{"param1": "some_value"},
{"allow_local_infile": 0, "param1": "some_value"},
),
(
"mysql://user:password@host/db1",
{"local_infile": 1, "param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
(
"mysql+mysqlconnector://user:password@host/db1",
{"allow_local_infile": 1, "param1": "some_value"},
{"allow_local_infile": 0, "param1": "some_value"},
),
],
)
def test_adjust_engine_params(
Expand Down

0 comments on commit a63fefc

Please sign in to comment.