diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index a2bf2bc7be9b9..8556f158be667 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -354,10 +354,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # This set will give the keywords for data limit statements # to consider for the engines with TOP SQL parsing top_keywords: Set[str] = {"TOP"} - # A set of disallowed connection query parameters - disallow_uri_query_params: Set[str] = set() + # A set of disallowed connection query parameters by driver name + disallow_uri_query_params: Dict[str, Set[str]] = {} # A Dict of query parameters that will always be used on every connection - enforce_uri_query_params: Dict[str, Any] = {} + # by driver name + enforce_uri_query_params: Dict[str, Dict[str, Any]] = {} force_column_alias_quotes = False arraysize = 0 @@ -1024,7 +1025,10 @@ def adjust_database_uri( # pylint: disable=unused-argument This is important because DB engine specs can be installed from 3rd party packages. """ - return uri, {**connect_args, **cls.enforce_uri_query_params} + return uri, { + **connect_args, + **cls.enforce_uri_query_params.get(uri.get_driver_name(), {}), + } @classmethod def patch(cls) -> None: @@ -1744,9 +1748,9 @@ def validate_database_uri(cls, sqlalchemy_uri: URL) -> None: :param sqlalchemy_uri: """ - if existing_disallowed := cls.disallow_uri_query_params.intersection( - sqlalchemy_uri.query - ): + if existing_disallowed := cls.disallow_uri_query_params.get( + sqlalchemy_uri.get_driver_name(), set() + ).intersection(sqlalchemy_uri.query): raise ValueError(f"Forbidden query parameter(s): {existing_disallowed}") diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index c9bfa73d13a2d..a4d08b89f7899 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -173,8 +173,14 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): {}, ), } - disallow_uri_query_params = {"local_infile"} - enforce_uri_query_params = {"local_infile": 0} + disallow_uri_query_params = { + "mysqldb": {"local_infile"}, + "mysqlconnector": {"allow_local_infile"}, + } + enforce_uri_query_params = { + "mysqldb": {"local_infile": 0}, + "mysqlconnector": {"allow_local_infile": 0}, + } @classmethod def convert_dttm( diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index 35dbcc0a6bb3a..d5684b1b62109 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -194,7 +194,7 @@ def test_impersonate_user_presto(self, mocked_create_engine): @mock.patch("superset.models.core.create_engine") def test_adjust_engine_params_mysql(self, mocked_create_engine): model = Database( - database_name="test_database", + database_name="test_database1", sqlalchemy_uri="mysql://user:password@localhost", ) model._get_sqla_engine() @@ -203,6 +203,16 @@ def test_adjust_engine_params_mysql(self, mocked_create_engine): assert str(call_args[0][0]) == "mysql://user:password@localhost" assert call_args[1]["connect_args"]["local_infile"] == 0 + model = Database( + database_name="test_database2", + sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost", + ) + model._get_sqla_engine() + call_args = mocked_create_engine.call_args + + assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost" + assert call_args[1]["connect_args"]["allow_local_infile"] == 0 + @mock.patch("superset.models.core.create_engine") def test_impersonate_user_trino(self, mocked_create_engine): principal_user = security_manager.find_user(username="gamma") diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py index fe2da19737516..3a5c161128fb0 100644 --- a/tests/unit_tests/db_engine_specs/test_mysql.py +++ b/tests/unit_tests/db_engine_specs/test_mysql.py @@ -104,8 +104,11 @@ def test_convert_dttm( "sqlalchemy_uri,error", [ ("mysql://user:password@host/db1?local_infile=1", True), + ("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=1", True), ("mysql://user:password@host/db1?local_infile=0", True), + ("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=0", True), ("mysql://user:password@host/db1", False), + ("mysql+mysqlconnector://user:password@host/db1", False), ], ) def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None: @@ -123,18 +126,43 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None: "sqlalchemy_uri,connect_args,returns", [ ("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}), + ( + "mysql+mysqlconnector://user:password@host/db1", + {"allow_local_infile": 1}, + {"allow_local_infile": 0}, + ), ("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}), + ( + "mysql+mysqlconnector://user:password@host/db1", + {"allow_local_infile": -1}, + {"allow_local_infile": 0}, + ), ("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}), + ( + "mysql+mysqlconnector://user:password@host/db1", + {"allow_local_infile": 0}, + {"allow_local_infile": 0}, + ), ( "mysql://user:password@host/db1", {"param1": "some_value"}, {"local_infile": 0, "param1": "some_value"}, ), + ( + "mysql+mysqlconnector://user:password@host/db1", + {"param1": "some_value"}, + {"allow_local_infile": 0, "param1": "some_value"}, + ), ( "mysql://user:password@host/db1", {"local_infile": 1, "param1": "some_value"}, {"local_infile": 0, "param1": "some_value"}, ), + ( + "mysql+mysqlconnector://user:password@host/db1", + {"allow_local_infile": 1, "param1": "some_value"}, + {"allow_local_infile": 0, "param1": "some_value"}, + ), ], ) def test_adjust_engine_params(