Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow db driver distinction on enforced URI params #23769

Merged
merged 1 commit into from
Apr 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# This set will give the keywords for data limit statements
# to consider for the engines with TOP SQL parsing
top_keywords: Set[str] = {"TOP"}
# A set of disallowed connection query parameters
disallow_uri_query_params: Set[str] = set()
# A set of disallowed connection query parameters by driver name
disallow_uri_query_params: Dict[str, Set[str]] = {}
# A Dict of query parameters that will always be used on every connection
enforce_uri_query_params: Dict[str, Any] = {}
# by driver name
enforce_uri_query_params: Dict[str, Dict[str, Any]] = {}

force_column_alias_quotes = False
arraysize = 0
Expand Down Expand Up @@ -1096,7 +1097,10 @@ def adjust_engine_params( # pylint: disable=unused-argument
This is important because DB engine specs can be installed from 3rd party
packages.
"""
return uri, {**connect_args, **cls.enforce_uri_query_params}
return uri, {
**connect_args,
**cls.enforce_uri_query_params.get(uri.get_driver_name(), {}),
}

@classmethod
def patch(cls) -> None:
Expand Down Expand Up @@ -1850,9 +1854,9 @@ def validate_database_uri(cls, sqlalchemy_uri: URL) -> None:
:param sqlalchemy_uri:
"""
if existing_disallowed := cls.disallow_uri_query_params.intersection(
sqlalchemy_uri.query
):
if existing_disallowed := cls.disallow_uri_query_params.get(
sqlalchemy_uri.get_driver_name(), set()
).intersection(sqlalchemy_uri.query):
raise ValueError(f"Forbidden query parameter(s): {existing_disallowed}")


Expand Down
10 changes: 8 additions & 2 deletions superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,14 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
{},
),
}
disallow_uri_query_params = {"local_infile"}
enforce_uri_query_params = {"local_infile": 0}
disallow_uri_query_params = {
"mysqldb": {"local_infile"},
"mysqlconnector": {"allow_local_infile"},
}
enforce_uri_query_params = {
"mysqldb": {"local_infile": 0},
"mysqlconnector": {"allow_local_infile": 0},
}
Comment on lines -178 to +185
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thinking out loud: would it be convenient to be able to specify globally disallowed params, too? Something like this:

    disallow_uri_query_params = {
        "": {"my_globally_disallowed_param"},
        "mysqldb": {"local_infile"},
        "mysqlconnector": {"allow_local_infile"},
    }

(feel free to ignore this idea, as it looks pretty awful 😆)


@classmethod
def convert_dttm(
Expand Down
12 changes: 11 additions & 1 deletion tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_impersonate_user_presto(self, mocked_create_engine):
@mock.patch("superset.models.core.create_engine")
def test_adjust_engine_params_mysql(self, mocked_create_engine):
model = Database(
database_name="test_database",
database_name="test_database1",
sqlalchemy_uri="mysql://user:password@localhost",
)
model._get_sqla_engine()
Expand All @@ -203,6 +203,16 @@ def test_adjust_engine_params_mysql(self, mocked_create_engine):
assert str(call_args[0][0]) == "mysql://user:password@localhost"
assert call_args[1]["connect_args"]["local_infile"] == 0

model = Database(
database_name="test_database2",
sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost",
)
model._get_sqla_engine()
call_args = mocked_create_engine.call_args

assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost"
assert call_args[1]["connect_args"]["allow_local_infile"] == 0

@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
principal_user = security_manager.find_user(username="gamma")
Expand Down
28 changes: 28 additions & 0 deletions tests/unit_tests/db_engine_specs/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,11 @@ def test_convert_dttm(
"sqlalchemy_uri,error",
[
("mysql://user:password@host/db1?local_infile=1", True),
("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=1", True),
("mysql://user:password@host/db1?local_infile=0", True),
("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=0", True),
("mysql://user:password@host/db1", False),
("mysql+mysqlconnector://user:password@host/db1", False),
],
)
def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
Expand All @@ -123,18 +126,43 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
"sqlalchemy_uri,connect_args,returns",
[
("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
(
"mysql+mysqlconnector://user:password@host/db1",
{"allow_local_infile": 1},
{"allow_local_infile": 0},
),
("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
(
"mysql+mysqlconnector://user:password@host/db1",
{"allow_local_infile": -1},
{"allow_local_infile": 0},
),
("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
(
"mysql+mysqlconnector://user:password@host/db1",
{"allow_local_infile": 0},
{"allow_local_infile": 0},
),
(
"mysql://user:password@host/db1",
{"param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
(
"mysql+mysqlconnector://user:password@host/db1",
{"param1": "some_value"},
{"allow_local_infile": 0, "param1": "some_value"},
),
(
"mysql://user:password@host/db1",
{"local_infile": 1, "param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
(
"mysql+mysqlconnector://user:password@host/db1",
{"allow_local_infile": 1, "param1": "some_value"},
{"allow_local_infile": 0, "param1": "some_value"},
),
],
)
def test_adjust_engine_params(
Expand Down