From 1618903c383db2f7b8a6b2f328a23192cc3ec62d Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 2 Aug 2024 10:49:07 -0400 Subject: [PATCH 1/8] WIP --- scripts/python_tests.sh | 2 +- superset/cachekeys/schemas.py | 4 ++++ superset/commands/dataset/create.py | 8 +++---- .../commands/dataset/importers/v1/utils.py | 2 +- superset/commands/dataset/update.py | 7 ++++++- superset/daos/dataset.py | 20 ++++++++++++------ superset/models/core.py | 10 ++++----- superset/models/slice.py | 1 + superset/security/manager.py | 21 +++++++++++-------- .../integration_tests/dashboards/api_tests.py | 4 +++- .../db_engine_specs/postgres_tests.py | 5 +---- .../unit_tests/connectors/sqla/models_test.py | 4 ++-- tests/unit_tests/dao/dataset_test.py | 6 +++--- 13 files changed, 57 insertions(+), 37 deletions(-) diff --git a/scripts/python_tests.sh b/scripts/python_tests.sh index e127d0c020621..6f3f3bddb8317 100755 --- a/scripts/python_tests.sh +++ b/scripts/python_tests.sh @@ -33,4 +33,4 @@ superset load-test-users echo "Running tests" -pytest --durations-min=2 --maxfail=1 --cov-report= --cov=superset ./tests/integration_tests "$@" +pytest --durations-min=2 --cov-report= --cov=superset ./tests/integration_tests "$@" diff --git a/superset/cachekeys/schemas.py b/superset/cachekeys/schemas.py index e58a45ac565b9..d31e40b7d4075 100644 --- a/superset/cachekeys/schemas.py +++ b/superset/cachekeys/schemas.py @@ -32,6 +32,10 @@ class Datasource(Schema): datasource_name = fields.String( metadata={"description": datasource_name_description}, ) + catalog = fields.String( + allow_none=True, + metadata={"description": "Datasource catalog"}, + ) schema = fields.String( metadata={"description": "Datasource schema"}, ) diff --git a/superset/commands/dataset/create.py b/superset/commands/dataset/create.py index a2d81e548bfb0..aa5d62b2f2196 100644 --- a/superset/commands/dataset/create.py +++ b/superset/commands/dataset/create.py @@ -61,16 +61,16 @@ def validate(self) -> None: table = Table(self._properties["table_name"], schema, catalog) - # Validate uniqueness - if not DatasetDAO.validate_uniqueness(database_id, table): - exceptions.append(DatasetExistsValidationError(table)) - # Validate/Populate database database = DatasetDAO.get_database_by_id(database_id) if not database: exceptions.append(DatabaseNotFoundValidationError()) self._properties["database"] = database + # Validate uniqueness + if database and not DatasetDAO.validate_uniqueness(database, table): + exceptions.append(DatasetExistsValidationError(table)) + # Validate table exists on dataset if sql is not provided # This should be validated when the dataset is physical if ( diff --git a/superset/commands/dataset/importers/v1/utils.py b/superset/commands/dataset/importers/v1/utils.py index 1c508fe2522e8..58945108462fc 100644 --- a/superset/commands/dataset/importers/v1/utils.py +++ b/superset/commands/dataset/importers/v1/utils.py @@ -166,7 +166,7 @@ def import_dataset( try: table_exists = dataset.database.has_table( - Table(dataset.table_name, dataset.schema), + Table(dataset.table_name, dataset.schema, dataset.catalog), ) except Exception: # pylint: disable=broad-except # MySQL doesn't play nice with GSheets table names diff --git a/superset/commands/dataset/update.py b/superset/commands/dataset/update.py index 14d1c5ef44707..e0841c85ba1e2 100644 --- a/superset/commands/dataset/update.py +++ b/superset/commands/dataset/update.py @@ -79,10 +79,12 @@ def run(self) -> Model: def validate(self) -> None: exceptions: list[ValidationError] = [] owner_ids: Optional[list[int]] = self._properties.get("owners") + # Validate/populate model exists self._model = DatasetDAO.find_by_id(self._model_id) if not self._model: raise DatasetNotFoundError() + # Check ownership try: security_manager.raise_for_ownership(self._model) @@ -99,14 +101,16 @@ def validate(self) -> None: # Validate uniqueness if not DatasetDAO.validate_update_uniqueness( - self._model.database_id, + self._model.database, table, self._model_id, ): exceptions.append(DatasetExistsValidationError(table)) + # Validate/Populate database not allowed to change if database_id and database_id != self._model: exceptions.append(DatabaseChangeValidationError()) + # Validate/Populate owner try: owners = self.compute_owners( @@ -116,6 +120,7 @@ def validate(self) -> None: self._properties["owners"] = owners except ValidationError as ex: exceptions.append(ex) + # Validate columns if columns := self._properties.get("columns"): self._validate_columns(columns, exceptions) diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py index af1b705d66109..57d498661fe88 100644 --- a/superset/daos/dataset.py +++ b/superset/daos/dataset.py @@ -84,15 +84,19 @@ def validate_table_exists( @staticmethod def validate_uniqueness( - database_id: int, + database: Database, table: Table, dataset_id: int | None = None, ) -> bool: + # The catalog might not be set even if the database supports catalogs, in case + # multi-catalog is disabled. + catalog = table.catalog or database.get_default_catalog() + dataset_query = db.session.query(SqlaTable).filter( SqlaTable.table_name == table.table, SqlaTable.schema == table.schema, - SqlaTable.catalog == table.catalog, - SqlaTable.database_id == database_id, + SqlaTable.catalog == catalog, + SqlaTable.database_id == database.id, ) if dataset_id: @@ -103,15 +107,19 @@ def validate_uniqueness( @staticmethod def validate_update_uniqueness( - database_id: int, + database: Database, table: Table, dataset_id: int, ) -> bool: + # The catalog might not be set even if the database supports catalogs, in case + # multi-catalog is disabled. + catalog = table.catalog or database.get_default_catalog() + dataset_query = db.session.query(SqlaTable).filter( SqlaTable.table_name == table.table, - SqlaTable.database_id == database_id, + SqlaTable.database_id == database.id, SqlaTable.schema == table.schema, - SqlaTable.catalog == table.catalog, + SqlaTable.catalog == catalog, SqlaTable.id != dataset_id, ) return not db.session.query(dataset_query.exists()).scalar() diff --git a/superset/models/core.py b/superset/models/core.py index 512c5a93e300e..ed550625eb242 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -116,7 +116,9 @@ class ConfigurationMethod(StrEnum): DYNAMIC_FORM = "dynamic_form" -class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods +class Database( + Model, AuditMixinNullable, ImportExportMixin +): # pylint: disable=too-many-public-methods """An ORM object that stores Database related information""" __tablename__ = "dbs" @@ -390,9 +392,7 @@ def get_effective_user(self, object_url: URL) -> str | None: return ( username if (username := get_username()) - else object_url.username - if self.impersonate_user - else None + else object_url.username if self.impersonate_user else None ) @contextmanager @@ -490,7 +490,7 @@ def _get_sqla_engine( # pylint: disable=too-many-locals g.user.id, self.db_engine_spec, ) - if hasattr(g, "user") and hasattr(g.user, "id") and oauth2_config + if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id") else None ) # If using MySQL or Presto for example, will set url.username diff --git a/superset/models/slice.py b/superset/models/slice.py index c30e643b7df24..cf94a50f51eb2 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -369,6 +369,7 @@ def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) -> ds = db.session.query(src_class).filter_by(id=int(id_)).first() if ds: target.perm = ds.perm + target.catalog_perm = ds.catalog_perm target.schema_perm = ds.schema_perm diff --git a/superset/security/manager.py b/superset/security/manager.py index d792282cbcd45..90f1199d4b9b0 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -774,6 +774,9 @@ def get_schemas_accessible_by_user( # pylint: disable=import-outside-toplevel from superset.connectors.sqla.models import SqlaTable + default_catalog = database.get_default_catalog() + catalog = catalog or default_catalog + if hierarchical and ( self.can_access_database(database) or (catalog and self.can_access_catalog(database, catalog)) @@ -783,7 +786,6 @@ def get_schemas_accessible_by_user( # schema_access accessible_schemas: set[str] = set() schema_access = self.user_view_menu_names("schema_access") - default_catalog = database.get_default_catalog() default_schema = database.get_default_schema(default_catalog) for perm in schema_access: @@ -800,7 +802,7 @@ def get_schemas_accessible_by_user( # [database].[catalog].[schema] matches when the catalog is equal to the # requested catalog or, when no catalog specified, it's equal to the default # catalog. - elif len(parts) == 3 and parts[1] == (catalog or default_catalog): + elif len(parts) == 3 and parts[1] == catalog: accessible_schemas.add(parts[2]) # datasource_access @@ -906,16 +908,16 @@ def get_datasources_accessible_by_user( # pylint: disable=invalid-name if self.can_access_database(database): return datasource_names + catalog = catalog or database.get_default_catalog() if catalog: catalog_perm = self.get_catalog_perm(database.database_name, catalog) if catalog_perm and self.can_access("catalog_access", catalog_perm): return datasource_names if schema: - default_catalog = database.get_default_catalog() schema_perm = self.get_schema_perm( database.database_name, - catalog or default_catalog, + catalog, schema, ) if schema_perm and self.can_access("schema_access", schema_perm): @@ -2183,6 +2185,7 @@ def raise_for_access( database = query.database database = cast("Database", database) + default_catalog = database.get_default_catalog() if self.can_access_database(database): return @@ -2196,19 +2199,19 @@ def raise_for_access( # from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy # inspector to read it. default_schema = database.get_default_schema_for_query(query) - # Determining the default catalog is much easier, because DB engine - # specs need explicit support for catalogs. - default_catalog = database.get_default_catalog() tables = { Table( table_.table, table_.schema or default_schema, - table_.catalog or default_catalog, + table_.catalog or query.catalog or default_catalog, ) for table_ in extract_tables_from_jinja_sql(query.sql, database) } elif table: - tables = {table} + # Make sure table has the default catalog, if not specified. + tables = { + Table(table.table, table.schema, table.catalog or default_catalog) + } denied = set() diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index 259b9485fbe78..ff0be914aaced 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -2041,7 +2041,9 @@ def test_export(self): rv = self.get_assert_metric(uri, "export") - headers = "attachment; filename=dashboard_export_20220101T000000.zip" # noqa: F541 + headers = ( + "attachment; filename=dashboard_export_20220101T000000.zip" # noqa: F541 + ) assert rv.status_code == 200 assert rv.headers["Content-Disposition"] == headers diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index f21dbf54added..1af50ae442969 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -527,11 +527,8 @@ def test_get_catalog_names(app_context: AppContext) -> None: """ database = get_example_database() - if database.backend != "postgresql": - return - with database.get_inspector() as inspector: assert PostgresEngineSpec.get_catalog_names(database, inspector) == { "postgres", - "superset", + "test", } diff --git a/tests/unit_tests/connectors/sqla/models_test.py b/tests/unit_tests/connectors/sqla/models_test.py index c1e06f3755dd6..3fa32228ca48e 100644 --- a/tests/unit_tests/connectors/sqla/models_test.py +++ b/tests/unit_tests/connectors/sqla/models_test.py @@ -255,11 +255,11 @@ def test_dataset_uniqueness(session: Session) -> None: # but the DAO enforces application logic for uniqueness assert not DatasetDAO.validate_uniqueness( - database.id, + database, Table("table", "schema", None), ) assert DatasetDAO.validate_uniqueness( - database.id, + database, Table("table", "schema", "some_catalog"), ) diff --git a/tests/unit_tests/dao/dataset_test.py b/tests/unit_tests/dao/dataset_test.py index 473d1e27b7660..2b0b5c3d5f2e6 100644 --- a/tests/unit_tests/dao/dataset_test.py +++ b/tests/unit_tests/dao/dataset_test.py @@ -53,7 +53,7 @@ def test_validate_update_uniqueness(session: Session) -> None: assert ( DatasetDAO.validate_update_uniqueness( - database_id=database.id, + database=database, table=Table(dataset1.table_name, dataset1.schema), dataset_id=dataset1.id, ) @@ -62,7 +62,7 @@ def test_validate_update_uniqueness(session: Session) -> None: assert ( DatasetDAO.validate_update_uniqueness( - database_id=database.id, + database=database, table=Table(dataset1.table_name, dataset2.schema), dataset_id=dataset1.id, ) @@ -71,7 +71,7 @@ def test_validate_update_uniqueness(session: Session) -> None: assert ( DatasetDAO.validate_update_uniqueness( - database_id=database.id, + database=database, table=Table(dataset1.table_name), dataset_id=dataset1.id, ) From 944df5cac8423ba110ee7516c1ff7bac7acaff04 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Sun, 4 Aug 2024 10:49:28 -0400 Subject: [PATCH 2/8] Fixing tests --- .../db_engine_specs/postgres_tests.py | 5 ++++- tests/integration_tests/security_tests.py | 21 ++++++++++++++----- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index 1af50ae442969..98ce723f82c9f 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -25,6 +25,7 @@ from superset.db_engine_specs.postgres import PostgresEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.sql_lab import Query +from superset.utils.core import backend from superset.utils.database import get_example_database from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec from tests.integration_tests.fixtures.certificates import ssl_certificate @@ -525,8 +526,10 @@ def test_get_catalog_names(app_context: AppContext) -> None: """ Test the ``get_catalog_names`` method. """ - database = get_example_database() + if backend() != "postgresql": + return + database = get_example_database() with database.get_inspector() as inspector: assert PostgresEngineSpec.get_catalog_names(database, inspector) == { "postgres", diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 5b8e4f2ae00e7..774fc462cc0bf 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -898,7 +898,9 @@ def test_after_update_dataset__db_changes(self): db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() ) self.assertEqual(changed_table1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})") - self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]") # noqa: F541 + self.assertEqual( + changed_table1.schema_perm, "[tmp_db2].[tmp_schema]" + ) # noqa: F541 # Test Chart permission changed slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() @@ -956,12 +958,16 @@ def test_after_update_dataset__schema_changes(self): db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() ) self.assertEqual(changed_table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") - self.assertEqual(changed_table1.schema_perm, "[tmp_db1].[tmp_schema_changed]") # noqa: F541 + self.assertEqual( + changed_table1.schema_perm, "[tmp_db1].[tmp_schema_changed]" + ) # noqa: F541 # Test Chart schema permission changed slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() self.assertEqual(slice1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") - self.assertEqual(slice1.schema_perm, "[tmp_db1].[tmp_schema_changed]") # noqa: F541 + self.assertEqual( + slice1.schema_perm, "[tmp_db1].[tmp_schema_changed]" + ) # noqa: F541 # cleanup db.session.delete(slice1) @@ -1069,7 +1075,9 @@ def test_after_update_dataset__name_db_changes(self): self.assertEqual( changed_table1.perm, f"[tmp_db2].[tmp_table1_changed](id:{table1.id})" ) - self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]") # noqa: F541 + self.assertEqual( + changed_table1.schema_perm, "[tmp_db2].[tmp_schema]" + ) # noqa: F541 # Test Chart permission changed slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() @@ -1633,7 +1641,10 @@ def test_raise_for_access_datasource( @patch("superset.security.SupersetSecurityManager.can_access") def test_raise_for_access_query(self, mock_can_access, mock_is_owner): query = Mock( - database=get_example_database(), schema="bar", sql="SELECT * FROM foo" + database=get_example_database(), + schema="bar", + sql="SELECT * FROM foo", + catalog=None, ) mock_can_access.return_value = True From 23dc294debd90a0a46052fdb81df644870496756 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 6 Aug 2024 11:42:41 -0400 Subject: [PATCH 3/8] Fix more tests --- superset/models/core.py | 8 ++++---- tests/integration_tests/dashboards/api_tests.py | 4 +--- tests/integration_tests/security_tests.py | 16 ++++------------ tests/unit_tests/security/manager_test.py | 15 ++++++++++++--- 4 files changed, 21 insertions(+), 22 deletions(-) diff --git a/superset/models/core.py b/superset/models/core.py index ed550625eb242..be9251b606688 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -116,9 +116,7 @@ class ConfigurationMethod(StrEnum): DYNAMIC_FORM = "dynamic_form" -class Database( - Model, AuditMixinNullable, ImportExportMixin -): # pylint: disable=too-many-public-methods +class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods """An ORM object that stores Database related information""" __tablename__ = "dbs" @@ -392,7 +390,9 @@ def get_effective_user(self, object_url: URL) -> str | None: return ( username if (username := get_username()) - else object_url.username if self.impersonate_user else None + else object_url.username + if self.impersonate_user + else None ) @contextmanager diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index ff0be914aaced..259b9485fbe78 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -2041,9 +2041,7 @@ def test_export(self): rv = self.get_assert_metric(uri, "export") - headers = ( - "attachment; filename=dashboard_export_20220101T000000.zip" # noqa: F541 - ) + headers = "attachment; filename=dashboard_export_20220101T000000.zip" # noqa: F541 assert rv.status_code == 200 assert rv.headers["Content-Disposition"] == headers diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 774fc462cc0bf..bd76448d4899f 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -898,9 +898,7 @@ def test_after_update_dataset__db_changes(self): db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() ) self.assertEqual(changed_table1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})") - self.assertEqual( - changed_table1.schema_perm, "[tmp_db2].[tmp_schema]" - ) # noqa: F541 + self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]") # noqa: F541 # Test Chart permission changed slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() @@ -958,16 +956,12 @@ def test_after_update_dataset__schema_changes(self): db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() ) self.assertEqual(changed_table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") - self.assertEqual( - changed_table1.schema_perm, "[tmp_db1].[tmp_schema_changed]" - ) # noqa: F541 + self.assertEqual(changed_table1.schema_perm, "[tmp_db1].[tmp_schema_changed]") # noqa: F541 # Test Chart schema permission changed slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() self.assertEqual(slice1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") - self.assertEqual( - slice1.schema_perm, "[tmp_db1].[tmp_schema_changed]" - ) # noqa: F541 + self.assertEqual(slice1.schema_perm, "[tmp_db1].[tmp_schema_changed]") # noqa: F541 # cleanup db.session.delete(slice1) @@ -1075,9 +1069,7 @@ def test_after_update_dataset__name_db_changes(self): self.assertEqual( changed_table1.perm, f"[tmp_db2].[tmp_table1_changed](id:{table1.id})" ) - self.assertEqual( - changed_table1.schema_perm, "[tmp_db2].[tmp_schema]" - ) # noqa: F541 + self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]") # noqa: F541 # Test Chart permission changed slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() diff --git a/tests/unit_tests/security/manager_test.py b/tests/unit_tests/security/manager_test.py index 924e2cbf28ca6..660a2023e6f26 100644 --- a/tests/unit_tests/security/manager_test.py +++ b/tests/unit_tests/security/manager_test.py @@ -366,6 +366,7 @@ def test_raise_for_access_query_default_schema( database.get_default_catalog.return_value = None database.get_default_schema_for_query.return_value = "public" query = mocker.MagicMock() + query.catalog = None query.database = database query.sql = "SELECT * FROM ab_user" @@ -421,6 +422,7 @@ def test_raise_for_access_jinja_sql(mocker: MockerFixture, app_context: None) -> database.get_default_catalog.return_value = None database.get_default_schema_for_query.return_value = "public" query = mocker.MagicMock() + query.catalog = None query.database = database query.sql = "SELECT * FROM {% if True %}ab_user{% endif %} WHERE 1=1" @@ -434,7 +436,7 @@ def test_raise_for_access_jinja_sql(mocker: MockerFixture, app_context: None) -> viz=None, ) - get_table_access_error_object.assert_called_with({Table("ab_user", "public")}) + get_table_access_error_object.assert_called_with({Table("ab_user", "public", None)}) def test_raise_for_access_chart_for_datasource_permission( @@ -736,6 +738,7 @@ def test_raise_for_access_catalog( database.get_default_catalog.return_value = "db1" database.get_default_schema_for_query.return_value = "public" query = mocker.MagicMock() + query.catalog = "db1" query.database = database query.sql = "SELECT * FROM ab_user" @@ -776,7 +779,8 @@ def test_get_datasources_accessible_by_user_schema_access( database.database_name = "db1" database.get_default_catalog.return_value = "catalog2" - can_access = mocker.patch.object(sm, "can_access", return_value=True) + # False for catalog_access, True for schema_access + can_access = mocker.patch.object(sm, "can_access", side_effect=[False, True]) datasource_names = [ DatasourceName("table1", "schema1", "catalog2"), @@ -795,7 +799,12 @@ def test_get_datasources_accessible_by_user_schema_access( # Even though we passed `catalog=None,` the schema check uses the default catalog # when building the schema permission, since the DB supports catalog. - can_access.assert_called_with("schema_access", "[db1].[catalog2].[schema1]") + can_access.assert_has_calls( + [ + mocker.call("catalog_access", "[db1].[catalog2]"), + mocker.call("schema_access", "[db1].[catalog2].[schema1]"), + ] + ) def test_get_catalogs_accessible_by_user_schema_access( From fd9290a77e2950824e50117be041524eb4426298 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 6 Aug 2024 12:43:55 -0400 Subject: [PATCH 4/8] Fix more tests --- scripts/change_detector.py | 2 +- tests/integration_tests/db_engine_specs/postgres_tests.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/change_detector.py b/scripts/change_detector.py index 39e4a5c8ccd25..f52cd59fec45f 100755 --- a/scripts/change_detector.py +++ b/scripts/change_detector.py @@ -52,7 +52,7 @@ def fetch_files_github_api(url: str): # type: ignore """Fetches data using GitHub API.""" req = Request(url) - req.add_header("Authorization", f"token {GITHUB_TOKEN}") + req.add_header("Authorization", f"Bearer {GITHUB_TOKEN}") req.add_header("Accept", "application/vnd.github.v3+json") print(f"Fetching from {url}") diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index 98ce723f82c9f..175ee65b2d0e2 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -533,5 +533,5 @@ def test_get_catalog_names(app_context: AppContext) -> None: with database.get_inspector() as inspector: assert PostgresEngineSpec.get_catalog_names(database, inspector) == { "postgres", - "test", + "superset", } From 03059e81049380dfdc85895a428d27507c0114fe Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 6 Aug 2024 15:03:36 -0400 Subject: [PATCH 5/8] Remove broken tests --- .../integration_tests/databases/api_tests.py | 28 ------- tests/integration_tests/datasets/api_tests.py | 73 ------------------- 2 files changed, 101 deletions(-) diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 8d0cd0810f8b1..f2ff214a78b21 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -1563,34 +1563,6 @@ def test_get_select_star_not_allowed(self): rv = self.client.get(uri) self.assertEqual(rv.status_code, 404) - def test_get_select_star_datasource_access(self): - """ - Database API: Test get select star with datasource access - """ - table = SqlaTable( - schema="main", table_name="ab_permission", database=get_main_database() - ) - db.session.add(table) - db.session.commit() - - tmp_table_perm = security_manager.find_permission_view_menu( - "datasource_access", table.get_perm() - ) - gamma_role = security_manager.find_role("Gamma") - security_manager.add_permission_role(gamma_role, tmp_table_perm) - - self.login(GAMMA_USERNAME) - main_db = get_main_database() - uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) - - # rollback changes - security_manager.del_permission_role(gamma_role, tmp_table_perm) - db.session.delete(table) - db.session.delete(main_db) - db.session.commit() - def test_get_select_star_not_found_database(self): """ Database API: Test get select star not found database diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 37de6e87c27ad..9a6789bbbfd67 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -36,7 +36,6 @@ from superset.extensions import db, security_manager from superset.models.core import Database from superset.models.slice import Slice -from superset.sql_parse import Table from superset.utils import json from superset.utils.core import backend, get_example_default_schema from superset.utils.database import get_example_database, get_main_database @@ -676,57 +675,6 @@ def test_create_dataset_item_owners_invalid(self): expected_result = {"message": {"owners": ["Owners are invalid"]}} assert data == expected_result - @pytest.mark.usefixtures("load_energy_table_with_slice") - def test_create_dataset_validate_uniqueness(self): - """ - Dataset API: Test create dataset validate table uniqueness - """ - - energy_usage_ds = self.get_energy_usage_dataset() - self.login(ADMIN_USERNAME) - table_data = { - "database": energy_usage_ds.database_id, - "table_name": energy_usage_ds.table_name, - } - if schema := get_example_default_schema(): - table_data["schema"] = schema - rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post") - assert rv.status_code == 422 - data = json.loads(rv.data.decode("utf-8")) - assert data == { - "message": { - "table": [ - f"Dataset {Table(energy_usage_ds.table_name, schema)} already exists" - ] - } - } - - @pytest.mark.usefixtures("load_energy_table_with_slice") - def test_create_dataset_with_sql_validate_uniqueness(self): - """ - Dataset API: Test create dataset with sql - """ - - energy_usage_ds = self.get_energy_usage_dataset() - self.login(ADMIN_USERNAME) - table_data = { - "database": energy_usage_ds.database_id, - "table_name": energy_usage_ds.table_name, - "sql": "select * from energy_usage", - } - if schema := get_example_default_schema(): - table_data["schema"] = schema - rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post") - assert rv.status_code == 422 - data = json.loads(rv.data.decode("utf-8")) - assert data == { - "message": { - "table": [ - f"Dataset {Table(energy_usage_ds.table_name, schema)} already exists" - ] - } - } - @pytest.mark.usefixtures("load_energy_table_with_slice") def test_create_dataset_with_sql(self): """ @@ -1455,27 +1403,6 @@ def test_update_dataset_item_owners_invalid(self): db.session.delete(dataset) db.session.commit() - def test_update_dataset_item_uniqueness(self): - """ - Dataset API: Test update dataset uniqueness - """ - - dataset = self.insert_default_dataset() - self.login(ADMIN_USERNAME) - ab_user = self.insert_dataset( - "ab_user", [self.get_user("admin").id], get_main_database() - ) - table_data = {"table_name": "ab_user"} - uri = f"api/v1/dataset/{dataset.id}" - rv = self.put_assert_metric(uri, table_data, "put") - data = json.loads(rv.data.decode("utf-8")) - assert rv.status_code == 422 - expected_response = {"message": {"table": ["Dataset ab_user already exists"]}} - assert data == expected_response - db.session.delete(dataset) - db.session.delete(ab_user) - db.session.commit() - @patch("superset.daos.dataset.DatasetDAO.update") def test_update_dataset_sqlalchemy_error(self, mock_dao_update): """ From 5e7d979bcbab0c288feb19a6b8e7925632fbc9ef Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 6 Aug 2024 16:14:51 -0400 Subject: [PATCH 6/8] Fix SQL Lab --- superset/commands/dataset/create.py | 15 ++++++++++----- superset/commands/dataset/update.py | 8 +++++++- superset/databases/api.py | 2 +- superset/jinja_context.py | 3 ++- superset/sqllab/api.py | 1 + superset/sqllab/sqllab_execution_context.py | 2 ++ superset/views/sql_lab/views.py | 2 ++ 7 files changed, 25 insertions(+), 8 deletions(-) diff --git a/superset/commands/dataset/create.py b/superset/commands/dataset/create.py index aa5d62b2f2196..ae6a0af4ed324 100644 --- a/superset/commands/dataset/create.py +++ b/superset/commands/dataset/create.py @@ -54,13 +54,12 @@ def run(self) -> Model: def validate(self) -> None: exceptions: list[ValidationError] = [] database_id = self._properties["database"] - schema = self._properties.get("schema") catalog = self._properties.get("catalog") + schema = self._properties.get("schema") + table_name = self._properties["table_name"] sql = self._properties.get("sql") owner_ids: Optional[list[int]] = self._properties.get("owners") - table = Table(self._properties["table_name"], schema, catalog) - # Validate/Populate database database = DatasetDAO.get_database_by_id(database_id) if not database: @@ -68,8 +67,14 @@ def validate(self) -> None: self._properties["database"] = database # Validate uniqueness - if database and not DatasetDAO.validate_uniqueness(database, table): - exceptions.append(DatasetExistsValidationError(table)) + if database: + if not catalog: + catalog = self._properties["catalog"] = database.get_default_catalog() + + table = Table(table_name, schema, catalog) + + if not DatasetDAO.validate_uniqueness(database, table): + exceptions.append(DatasetExistsValidationError(table)) # Validate table exists on dataset if sql is not provided # This should be validated when the dataset is physical diff --git a/superset/commands/dataset/update.py b/superset/commands/dataset/update.py index e0841c85ba1e2..2772cc0ffa1f6 100644 --- a/superset/commands/dataset/update.py +++ b/superset/commands/dataset/update.py @@ -93,10 +93,16 @@ def validate(self) -> None: database_id = self._properties.get("database") + catalog = self._properties.get("catalog") + if not catalog: + catalog = self._properties["catalog"] = ( + self._model.database.get_default_catalog() + ) + table = Table( self._properties.get("table_name"), # type: ignore self._properties.get("schema"), - self._properties.get("catalog"), + catalog, ) # Validate uniqueness diff --git a/superset/databases/api.py b/superset/databases/api.py index d490ac70dab55..695ea028b476d 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -1159,7 +1159,7 @@ def select_star( self.incr_stats("init", self.select_star.__name__) try: result = database.select_star( - Table(table_name, schema_name), + Table(table_name, schema_name, database.get_default_catalog()), latest_partition=True, ) except NoSuchTableError: diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 8d59eade155b8..10428db34e85a 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -501,6 +501,7 @@ def process_template(self, sql: str, **kwargs: Any) -> str: kwargs.update(self._context) context = validate_template_context(self.engine, kwargs) + print("FOO", type(template.render(context))) return template.render(context) @@ -565,7 +566,7 @@ def process_template(self, sql: str, **kwargs: Any) -> str: """ Makes processing a template a noop """ - return sql + return str(sql) class PrestoTemplateProcessor(JinjaTemplateProcessor): diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py index cdb331c19bc33..f7d66ed4e19fa 100644 --- a/superset/sqllab/api.py +++ b/superset/sqllab/api.py @@ -284,6 +284,7 @@ def export_csv(self, client_id: str) -> CsvResponse: "client_id": client_id, "row_count": row_count, "database": query.database.name, + "catalog": query.catalog, "schema": query.schema, "sql": query.sql, "exported_format": "csv", diff --git a/superset/sqllab/sqllab_execution_context.py b/superset/sqllab/sqllab_execution_context.py index 5ca180d101b55..ab0f91bbf30ca 100644 --- a/superset/sqllab/sqllab_execution_context.py +++ b/superset/sqllab/sqllab_execution_context.py @@ -125,6 +125,8 @@ def select_as_cta(self) -> bool: def set_database(self, database: Database) -> None: self._validate_db(database) self.database = database + if self.catalog is None: + self.catalog = database.get_default_catalog() if self.select_as_cta: schema_name = self._get_ctas_target_schema_name(database) self.create_table_as_select.target_schema_name = schema_name # type: ignore diff --git a/superset/views/sql_lab/views.py b/superset/views/sql_lab/views.py index 3ec3667267471..3b24f7c0eca3b 100644 --- a/superset/views/sql_lab/views.py +++ b/superset/views/sql_lab/views.py @@ -239,6 +239,7 @@ def post(self) -> FlaskResponse: db.session.query(TableSchema).filter( TableSchema.tab_state_id == table["queryEditorId"], TableSchema.database_id == table["dbId"], + TableSchema.catalog == table["catalog"], TableSchema.schema == table["schema"], TableSchema.table == table["name"], ).delete(synchronize_session=False) @@ -246,6 +247,7 @@ def post(self) -> FlaskResponse: table_schema = TableSchema( tab_state_id=table["queryEditorId"], database_id=table["dbId"], + catalog=table["catalog"], schema=table["schema"], table=table["name"], description=json.dumps(table), From 9023b487ad8b9a67dfc9674042480e9a759ecaf5 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 6 Aug 2024 17:08:37 -0400 Subject: [PATCH 7/8] Fix cache keys --- superset/connectors/sqla/models.py | 10 ++++++---- superset/jinja_context.py | 1 - tests/integration_tests/sqla_models_tests.py | 6 ++++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index c38a0085a534b..4b2126e491c2a 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -461,9 +461,11 @@ def data_for_slices( # pylint: disable=too-many-locals ) else: _columns = [ - utils.get_column_name(column_) - if utils.is_adhoc_column(column_) - else column_ + ( + utils.get_column_name(column_) + if utils.is_adhoc_column(column_) + else column_ + ) for column_param in COLUMN_FORM_DATA_PARAMS for column_ in utils.as_list(form_data.get(column_param) or []) ] @@ -1963,7 +1965,7 @@ class and any keys added via `ExtraCache`. if self.has_extra_cache_key_calls(query_obj): sqla_query = self.get_sqla_query(**query_obj) extra_cache_keys += sqla_query.extra_cache_keys - return extra_cache_keys + return list(set(extra_cache_keys)) @property def quote_identifier(self) -> Callable[[str], str]: diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 10428db34e85a..03ff43e8f4bcd 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -501,7 +501,6 @@ def process_template(self, sql: str, **kwargs: Any) -> str: kwargs.update(self._context) context = validate_template_context(self.engine, kwargs) - print("FOO", type(template.render(context))) return template.render(context) diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 86fffee1ec89a..4398d75c12f92 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -34,7 +34,9 @@ from superset.constants import EMPTY_STRING, NULL_STRING from superset.db_engine_specs.bigquery import BigQueryEngineSpec from superset.db_engine_specs.druid import DruidEngineSpec -from superset.exceptions import QueryObjectValidationError, SupersetSecurityException # noqa: F401 +from superset.exceptions import ( + QueryObjectValidationError, +) # noqa: F401 from superset.models.core import Database from superset.utils.core import ( AdhocMetricExpressionType, @@ -160,7 +162,7 @@ def test_extra_cache_keys(self, mock_user_email, mock_username, mock_user_id): query_obj = dict(**base_query_obj, extras={}) extra_cache_keys = table1.get_extra_cache_keys(query_obj) self.assertTrue(table1.has_extra_cache_key_calls(query_obj)) - assert extra_cache_keys == [1, "abc", "abc@test.com"] + assert set(extra_cache_keys) == {1, "abc", "abc@test.com"} # Table with Jinja callable disabled. table2 = SqlaTable( From c5a97978f38d6da93df27f26c33a02c46abc9a8c Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 6 Aug 2024 18:51:47 -0400 Subject: [PATCH 8/8] Undo test change --- scripts/python_tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/python_tests.sh b/scripts/python_tests.sh index 6f3f3bddb8317..e127d0c020621 100755 --- a/scripts/python_tests.sh +++ b/scripts/python_tests.sh @@ -33,4 +33,4 @@ superset load-test-users echo "Running tests" -pytest --durations-min=2 --cov-report= --cov=superset ./tests/integration_tests "$@" +pytest --durations-min=2 --maxfail=1 --cov-report= --cov=superset ./tests/integration_tests "$@"