diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index af2699a6dd0f3..27dd34a802d9e 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -354,10 +354,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # This set will give the keywords for data limit statements # to consider for the engines with TOP SQL parsing top_keywords: Set[str] = {"TOP"} - # A set of disallowed connection query parameters - disallow_uri_query_params: Set[str] = set() + # A Dict of disallowed connection query parameters, keyed by driver name + disallow_uri_query_params: Dict[str, Set[str]] = {} # A Dict of query parameters that will always be used on every connection - enforce_uri_query_params: Dict[str, Any] = {} + # by driver name + enforce_uri_query_params: Dict[str, Dict[str, Any]] = {} force_column_alias_quotes = False arraysize = 0 @@ -999,6 +1000,7 @@ def extract_errors( def adjust_database_uri( # pylint: disable=unused-argument cls, uri: URL, + connect_args: Dict[str, Any], selected_schema: Optional[str] = None, ) -> Tuple[URL, Dict[str, Any]]: """ @@ -1024,7 +1026,10 @@ def adjust_database_uri( # pylint: disable=unused-argument This is important because DB engine specs can be installed from 3rd party packages. 
""" - return uri, {**cls.enforce_uri_query_params} + return uri, { + **connect_args, + **cls.enforce_uri_query_params.get(uri.get_driver_name(), {}), + } @classmethod def patch(cls) -> None: @@ -1744,9 +1749,9 @@ def validate_database_uri(cls, sqlalchemy_uri: URL) -> None: :param sqlalchemy_uri: """ - if existing_disallowed := cls.disallow_uri_query_params.intersection( - sqlalchemy_uri.query - ): + if existing_disallowed := cls.disallow_uri_query_params.get( + sqlalchemy_uri.get_driver_name(), set() + ).intersection(sqlalchemy_uri.query): raise ValueError(f"Forbidden query parameter(s): {existing_disallowed}") diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py index 756f74e82ac6e..d8a1940007157 100644 --- a/superset/db_engine_specs/drill.py +++ b/superset/db_engine_specs/drill.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from urllib import parse from sqlalchemy import types @@ -69,11 +69,16 @@ def convert_dttm( return None @classmethod - def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL: + def adjust_database_uri( + cls, + uri: URL, + connect_args: Dict[str, Any], + selected_schema: Optional[str] = None, + ) -> Tuple[URL, Dict[str, Any]]: if selected_schema: uri = uri.set(database=parse.quote(selected_schema, safe="")) - return uri + return uri, connect_args @classmethod def get_url_for_impersonation( diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index c049ee652eee4..f07d53518c21e 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -191,7 +191,6 @@ def df_to_sql( raise SupersetException("Append operation not currently supported") if to_sql_kwargs["if_exists"] == "fail": - # Ensure table doesn't already exist. 
if table.schema: table_exists = not database.get_df( @@ -260,12 +259,15 @@ def convert_dttm( @classmethod def adjust_database_uri( - cls, uri: URL, selected_schema: Optional[str] = None - ) -> URL: + cls, + uri: URL, + connect_args: Dict[str, Any], + selected_schema: Optional[str] = None, + ) -> Tuple[URL, Dict[str, Any]]: if selected_schema: uri = uri.set(database=parse.quote(selected_schema, safe="")) - return uri + return uri, connect_args @classmethod def _extract_error_message(cls, ex: Exception) -> str: diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 457509f7a7cb2..622e6c985cb5b 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -173,8 +173,14 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): {}, ), } - disallow_uri_query_params = {"local_infile"} - enforce_uri_query_params = {"local_infile": 0} + disallow_uri_query_params = { + "mysqldb": {"local_infile"}, + "mysqlconnector": {"allow_local_infile"}, + } + enforce_uri_query_params = { + "mysqldb": {"local_infile": 0}, + "mysqlconnector": {"allow_local_infile": 0}, + } @classmethod def convert_dttm( @@ -191,11 +197,14 @@ def convert_dttm( @classmethod def adjust_database_uri( - cls, uri: URL, selected_schema: Optional[str] = None + cls, + uri: URL, + connect_args: Dict[str, Any], + selected_schema: Optional[str] = None, ) -> Tuple[URL, Dict[str, Any]]: uri, new_connect_args = super( MySQLEngineSpec, MySQLEngineSpec - ).adjust_database_uri(uri) + ).adjust_database_uri(uri, connect_args) if selected_schema: uri = uri.set(database=parse.quote(selected_schema, safe="")) diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 72931a85b420c..6bd556b79e39e 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -300,8 +300,11 @@ def epoch_to_dttm(cls) -> str: @classmethod def adjust_database_uri( - cls, uri: URL, selected_schema: Optional[str] = None - ) 
-> URL: + cls, + uri: URL, + connect_args: Dict[str, Any], + selected_schema: Optional[str] = None, + ) -> Tuple[URL, Dict[str, Any]]: database = uri.database if selected_schema and database: selected_schema = parse.quote(selected_schema, safe="") @@ -311,7 +314,7 @@ def adjust_database_uri( database += "/" + selected_schema uri = uri.set(database=database) - return uri + return uri, connect_args @classmethod def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 419e0a0655fe0..35801fa76845d 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -134,8 +134,11 @@ def get_extra_params(database: "Database") -> Dict[str, Any]: @classmethod def adjust_database_uri( - cls, uri: URL, selected_schema: Optional[str] = None - ) -> URL: + cls, + uri: URL, + connect_args: Dict[str, Any], + selected_schema: Optional[str] = None, + ) -> Tuple[URL, Dict[str, Any]]: database = uri.database if "/" in uri.database: database = uri.database.split("/")[0] @@ -143,7 +146,7 @@ def adjust_database_uri( selected_schema = parse.quote(selected_schema, safe="") uri = uri.set(database=f"{database}/{selected_schema}") - return uri + return uri, connect_args @classmethod def epoch_to_dttm(cls) -> str: @@ -222,7 +225,6 @@ def build_sqlalchemy_uri( Dict[str, Any] ] = None, ) -> str: - return str( URL( "snowflake", diff --git a/superset/models/core.py b/superset/models/core.py index 9c67a2efa6d2b..fce323b13c88e 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -426,7 +426,15 @@ def _get_sqla_engine( ) self.db_engine_spec.validate_database_uri(sqlalchemy_url) - sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema) + params = extra.get("engine_params", {}) + if nullpool: + params["poolclass"] = NullPool + + connect_args = params.get("connect_args", {}) + + sqlalchemy_url, connect_args = 
self.db_engine_spec.adjust_database_uri( + sqlalchemy_url, connect_args, schema + ) effective_username = self.get_effective_user(sqlalchemy_url) # If using MySQL or Presto for example, will set url.username # If using Hive, will not do anything yet since that relies on a @@ -438,11 +446,6 @@ def _get_sqla_engine( masked_url = self.get_password_masked_url(sqlalchemy_url) logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url)) - params = extra.get("engine_params", {}) - if nullpool: - params["poolclass"] = NullPool - - connect_args = params.get("connect_args", {}) if self.impersonate_user: self.db_engine_spec.update_impersonation_config( connect_args, str(sqlalchemy_url), effective_username diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index 35dbcc0a6bb3a..d5684b1b62109 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -194,7 +194,7 @@ def test_impersonate_user_presto(self, mocked_create_engine): @mock.patch("superset.models.core.create_engine") def test_adjust_engine_params_mysql(self, mocked_create_engine): model = Database( - database_name="test_database", + database_name="test_database1", sqlalchemy_uri="mysql://user:password@localhost", ) model._get_sqla_engine() @@ -203,6 +203,16 @@ def test_adjust_engine_params_mysql(self, mocked_create_engine): assert str(call_args[0][0]) == "mysql://user:password@localhost" assert call_args[1]["connect_args"]["local_infile"] == 0 + model = Database( + database_name="test_database2", + sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost", + ) + model._get_sqla_engine() + call_args = mocked_create_engine.call_args + + assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost" + assert call_args[1]["connect_args"]["allow_local_infile"] == 0 + @mock.patch("superset.models.core.create_engine") def test_impersonate_user_trino(self, mocked_create_engine): principal_user = 
security_manager.find_user(username="gamma") diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py index 3a24e1c2dc2d2..a6f0d99e04115 100644 --- a/tests/unit_tests/db_engine_specs/test_mysql.py +++ b/tests/unit_tests/db_engine_specs/test_mysql.py @@ -104,8 +104,11 @@ def test_convert_dttm( "sqlalchemy_uri,error", [ ("mysql://user:password@host/db1?local_infile=1", True), + ("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=1", True), ("mysql://user:password@host/db1?local_infile=0", True), + ("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=0", True), ("mysql://user:password@host/db1", False), + ("mysql+mysqlconnector://user:password@host/db1", False), ], ) def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None: @@ -123,18 +126,43 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None: "sqlalchemy_uri,connect_args,returns", [ ("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}), + ( + "mysql+mysqlconnector://user:password@host/db1", + {"allow_local_infile": 1}, + {"allow_local_infile": 0}, + ), ("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}), + ( + "mysql+mysqlconnector://user:password@host/db1", + {"allow_local_infile": -1}, + {"allow_local_infile": 0}, + ), ("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}), + ( + "mysql+mysqlconnector://user:password@host/db1", + {"allow_local_infile": 0}, + {"allow_local_infile": 0}, + ), ( "mysql://user:password@host/db1", {"param1": "some_value"}, {"local_infile": 0, "param1": "some_value"}, ), + ( + "mysql+mysqlconnector://user:password@host/db1", + {"param1": "some_value"}, + {"allow_local_infile": 0, "param1": "some_value"}, + ), ( "mysql://user:password@host/db1", {"local_infile": 1, "param1": "some_value"}, {"local_infile": 0, "param1": "some_value"}, ), + ( + "mysql+mysqlconnector://user:password@host/db1", + 
{"allow_local_infile": 1, "param1": "some_value"}, + {"allow_local_infile": 0, "param1": "some_value"}, + ), ], ) def test_adjust_database_uri( @@ -143,7 +171,9 @@ def test_adjust_database_uri( from superset.db_engine_specs.mysql import MySQLEngineSpec url = make_url(sqlalchemy_uri) - returned_url, returned_connect_args = MySQLEngineSpec.adjust_database_uri(url) + returned_url, returned_connect_args = MySQLEngineSpec.adjust_database_uri( + url, connect_args + ) assert returned_connect_args == returns