Skip to content

Commit

Permalink
fix: allow db driver distinction on enforced URI params (#23769)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar authored and eschutho committed Jun 2, 2023
1 parent dc5bed4 commit 2b3aa09
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 33 deletions.
19 changes: 12 additions & 7 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# This set will give the keywords for data limit statements
# to consider for the engines with TOP SQL parsing
top_keywords: Set[str] = {"TOP"}
# A set of disallowed connection query parameters
disallow_uri_query_params: Set[str] = set()
# A set of disallowed connection query parameters by driver name
disallow_uri_query_params: Dict[str, Set[str]] = {}
# A Dict of query parameters that will always be used on every connection
enforce_uri_query_params: Dict[str, Any] = {}
# by driver name
enforce_uri_query_params: Dict[str, Dict[str, Any]] = {}

force_column_alias_quotes = False
arraysize = 0
Expand Down Expand Up @@ -999,6 +1000,7 @@ def extract_errors(
def adjust_database_uri( # pylint: disable=unused-argument
cls,
uri: URL,
connect_args: Dict[str, Any],
selected_schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
"""
Expand All @@ -1024,7 +1026,10 @@ def adjust_database_uri( # pylint: disable=unused-argument
This is important because DB engine specs can be installed from 3rd party
packages.
"""
return uri, {**cls.enforce_uri_query_params}
return uri, {
**connect_args,
**cls.enforce_uri_query_params.get(uri.get_driver_name(), {}),
}

@classmethod
def patch(cls) -> None:
Expand Down Expand Up @@ -1744,9 +1749,9 @@ def validate_database_uri(cls, sqlalchemy_uri: URL) -> None:
:param sqlalchemy_uri:
"""
if existing_disallowed := cls.disallow_uri_query_params.intersection(
sqlalchemy_uri.query
):
if existing_disallowed := cls.disallow_uri_query_params.get(
sqlalchemy_uri.get_driver_name(), set()
).intersection(sqlalchemy_uri.query):
raise ValueError(f"Forbidden query parameter(s): {existing_disallowed}")


Expand Down
11 changes: 8 additions & 3 deletions superset/db_engine_specs/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple
from urllib import parse

from sqlalchemy import types
Expand Down Expand Up @@ -69,11 +69,16 @@ def convert_dttm(
return None

@classmethod
def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL:
def adjust_database_uri(
cls,
uri: URL,
connect_args: Dict[str, Any],
selected_schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))

return uri
return uri, connect_args

@classmethod
def get_url_for_impersonation(
Expand Down
10 changes: 6 additions & 4 deletions superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ def df_to_sql(
raise SupersetException("Append operation not currently supported")

if to_sql_kwargs["if_exists"] == "fail":

# Ensure table doesn't already exist.
if table.schema:
table_exists = not database.get_df(
Expand Down Expand Up @@ -260,12 +259,15 @@ def convert_dttm(

@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
cls,
uri: URL,
connect_args: Dict[str, Any],
selected_schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))

return uri
return uri, connect_args

@classmethod
def _extract_error_message(cls, ex: Exception) -> str:
Expand Down
17 changes: 13 additions & 4 deletions superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,14 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
{},
),
}
disallow_uri_query_params = {"local_infile"}
enforce_uri_query_params = {"local_infile": 0}
disallow_uri_query_params = {
"mysqldb": {"local_infile"},
"mysqlconnector": {"allow_local_infile"},
}
enforce_uri_query_params = {
"mysqldb": {"local_infile": 0},
"mysqlconnector": {"allow_local_infile": 0},
}

@classmethod
def convert_dttm(
Expand All @@ -191,11 +197,14 @@ def convert_dttm(

@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
cls,
uri: URL,
connect_args: Dict[str, Any],
selected_schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
uri, new_connect_args = super(
MySQLEngineSpec, MySQLEngineSpec
).adjust_database_uri(uri)
).adjust_database_uri(uri, connect_args)
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))

Expand Down
9 changes: 6 additions & 3 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,11 @@ def epoch_to_dttm(cls) -> str:

@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
cls,
uri: URL,
connect_args: Dict[str, Any],
selected_schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
database = uri.database
if selected_schema and database:
selected_schema = parse.quote(selected_schema, safe="")
Expand All @@ -311,7 +314,7 @@ def adjust_database_uri(
database += "/" + selected_schema
uri = uri.set(database=database)

return uri
return uri, connect_args

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
Expand Down
10 changes: 6 additions & 4 deletions superset/db_engine_specs/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,19 @@ def get_extra_params(database: "Database") -> Dict[str, Any]:

@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
cls,
uri: URL,
connect_args: Dict[str, Any],
selected_schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
database = uri.database
if "/" in uri.database:
database = uri.database.split("/")[0]
if selected_schema:
selected_schema = parse.quote(selected_schema, safe="")
uri = uri.set(database=f"{database}/{selected_schema}")

return uri
return uri, connect_args

@classmethod
def epoch_to_dttm(cls) -> str:
Expand Down Expand Up @@ -222,7 +225,6 @@ def build_sqlalchemy_uri(
Dict[str, Any]
] = None,
) -> str:

return str(
URL(
"snowflake",
Expand Down
15 changes: 9 additions & 6 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,15 @@ def _get_sqla_engine(
)
self.db_engine_spec.validate_database_uri(sqlalchemy_url)

sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
params = extra.get("engine_params", {})
if nullpool:
params["poolclass"] = NullPool

connect_args = params.get("connect_args", {})

sqlalchemy_url, connect_args = self.db_engine_spec.adjust_database_uri(
sqlalchemy_url, connect_args, schema
)
effective_username = self.get_effective_user(sqlalchemy_url)
# If using MySQL or Presto for example, will set url.username
# If using Hive, will not do anything yet since that relies on a
Expand All @@ -438,11 +446,6 @@ def _get_sqla_engine(
masked_url = self.get_password_masked_url(sqlalchemy_url)
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))

params = extra.get("engine_params", {})
if nullpool:
params["poolclass"] = NullPool

connect_args = params.get("connect_args", {})
if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
connect_args, str(sqlalchemy_url), effective_username
Expand Down
12 changes: 11 additions & 1 deletion tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_impersonate_user_presto(self, mocked_create_engine):
@mock.patch("superset.models.core.create_engine")
def test_adjust_engine_params_mysql(self, mocked_create_engine):
model = Database(
database_name="test_database",
database_name="test_database1",
sqlalchemy_uri="mysql://user:password@localhost",
)
model._get_sqla_engine()
Expand All @@ -203,6 +203,16 @@ def test_adjust_engine_params_mysql(self, mocked_create_engine):
assert str(call_args[0][0]) == "mysql://user:password@localhost"
assert call_args[1]["connect_args"]["local_infile"] == 0

model = Database(
database_name="test_database2",
sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost",
)
model._get_sqla_engine()
call_args = mocked_create_engine.call_args

assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost"
assert call_args[1]["connect_args"]["allow_local_infile"] == 0

@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
principal_user = security_manager.find_user(username="gamma")
Expand Down
32 changes: 31 additions & 1 deletion tests/unit_tests/db_engine_specs/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,11 @@ def test_convert_dttm(
"sqlalchemy_uri,error",
[
("mysql://user:password@host/db1?local_infile=1", True),
("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=1", True),
("mysql://user:password@host/db1?local_infile=0", True),
("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=0", True),
("mysql://user:password@host/db1", False),
("mysql+mysqlconnector://user:password@host/db1", False),
],
)
def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
Expand All @@ -123,18 +126,43 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
"sqlalchemy_uri,connect_args,returns",
[
("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
(
"mysql+mysqlconnector://user:password@host/db1",
{"allow_local_infile": 1},
{"allow_local_infile": 0},
),
("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
(
"mysql+mysqlconnector://user:password@host/db1",
{"allow_local_infile": -1},
{"allow_local_infile": 0},
),
("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
(
"mysql+mysqlconnector://user:password@host/db1",
{"allow_local_infile": 0},
{"allow_local_infile": 0},
),
(
"mysql://user:password@host/db1",
{"param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
(
"mysql+mysqlconnector://user:password@host/db1",
{"param1": "some_value"},
{"allow_local_infile": 0, "param1": "some_value"},
),
(
"mysql://user:password@host/db1",
{"local_infile": 1, "param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
(
"mysql+mysqlconnector://user:password@host/db1",
{"allow_local_infile": 1, "param1": "some_value"},
{"allow_local_infile": 0, "param1": "some_value"},
),
],
)
def test_adjust_database_uri(
Expand All @@ -143,7 +171,9 @@ def test_adjust_database_uri(
from superset.db_engine_specs.mysql import MySQLEngineSpec

url = make_url(sqlalchemy_uri)
returned_url, returned_connect_args = MySQLEngineSpec.adjust_database_uri(url)
returned_url, returned_connect_args = MySQLEngineSpec.adjust_database_uri(
url, connect_args
)
assert returned_connect_args == returns


Expand Down

0 comments on commit 2b3aa09

Please sign in to comment.