diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 1329597f02b72..0b4212197b14a 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -92,6 +92,12 @@ logger = logging.getLogger() +# When connecting to a database it's hard to catch specific exceptions, since we support +# more than 50 different database drivers. Usually the try/except block will catch the +# generic `Exception` class, which requires a pylint disablee comment. To make it clear +# that we know this is a necessary evil we create an alias, and catch it instead. +GenericDBException = Exception + def convert_inspector_columns(cols: list[SQLAColumnType]) -> list[ResultSetColumnType]: result_set_columns: list[ResultSetColumnType] = [] @@ -406,7 +412,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # # When this is changed to true in a DB engine spec it MUST support the # `get_default_catalog` and `get_catalog_names` methods. In addition, you MUST write - # a database migration updating any existing schema permissions. + # a database migration updating any existing schema permissions using the helper + # `upgrade_catalog_perms`. supports_catalog = False # Can the catalog be changed on a per-query basis? diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index d487f682aed6b..91860bf5b352d 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -434,8 +434,8 @@ def get_default_catalog( cls, database: Database, ) -> str | None: - with database.get_inspector() as inspector: - return inspector.bind.execute("SELECT current_catalog()").scalar() + with database.get_sqla_engine() as engine: + return engine.execute("SELECT current_catalog()").scalar() @classmethod def get_prequeries( diff --git a/superset/migrations/shared/catalogs.py b/superset/migrations/shared/catalogs.py index 27952371bdebf..b09c71739f8b7 100644 --- a/superset/migrations/shared/catalogs.py +++ b/superset/migrations/shared/catalogs.py @@ -26,8 +26,13 @@ from sqlalchemy.orm import Session from superset import db, security_manager -from superset.daos.database import DatabaseDAO -from superset.migrations.shared.security_converge import add_pvms, ViewMenu +from superset.db_engine_specs.base import GenericDBException +from superset.migrations.shared.security_converge import ( + add_pvms, + Permission, + PermissionView, + ViewMenu, +) from superset.models.core import Database logger = logging.getLogger(__name__) @@ -41,7 +46,9 @@ class SqlaTable(Base): id = sa.Column(sa.Integer, primary_key=True) database_id = sa.Column(sa.Integer, nullable=False) + perm = sa.Column(sa.String(1000)) schema_perm = sa.Column(sa.String(1000)) + catalog_perm = sa.Column(sa.String(1000), nullable=True, default=None) schema = sa.Column(sa.String(255)) catalog = sa.Column(sa.String(256), nullable=True, default=None) @@ -84,41 +91,47 @@ class Slice(Base): id = sa.Column(sa.Integer, primary_key=True) datasource_id = sa.Column(sa.Integer) datasource_type = sa.Column(sa.String(200)) + catalog_perm = sa.Column(sa.String(1000), nullable=True, default=None) schema_perm = sa.Column(sa.String(1000)) -def get_schemas(database_name: str) -> list[str]: +def get_known_schemas(database_name: str, session: Session) -> list[str]: """ - Read all known schemas from the schema permissions. + Read all known schemas from the existing schema permissions. """ - query = f""" -SELECT - avm.name -FROM ab_view_menu avm -JOIN ab_permission_view apv ON avm.id = apv.view_menu_id -JOIN ab_permission ap ON apv.permission_id = ap.id -WHERE - avm.name LIKE '[{database_name}]%' AND - ap.name = 'schema_access'; - """ - # [PostgreSQL].[postgres].[public] => public - conn = op.get_bind() - return sorted({row[0].split(".")[-1][1:-1] for row in conn.execute(query)}) + names = ( + session.query(ViewMenu.name) + .join(PermissionView, ViewMenu.id == PermissionView.view_menu_id) + .join(Permission, PermissionView.permission_id == Permission.id) + .filter( + ViewMenu.name.like(f"[{database_name}]%"), + Permission.name == "schema_access", + ) + .all() + ) + return sorted({name[0][1:-1].split("].[")[-1] for name in names}) def upgrade_catalog_perms(engines: set[str] | None = None) -> None: """ - Update models when catalogs are introduced in a DB engine spec. + Update models and permissions when catalogs are introduced in a DB engine spec. When an existing DB engine spec starts to support catalogs we need to: - - Add a `catalog_access` permission for each catalog. - - Populate the `catalog` field with the default catalog for each related model. + - Add `catalog_access` permissions for each catalog. + - Rename existing `schema_access` permissions to include the default catalog. + - Create `schema_access` permissions for each schema in the new catalogs. + + Also, for all the relevant existing models we need to: + + - Populate the `catalog` field with the default catalog. - Update `schema_perm` to include the default catalog. + - Populate `catalog_perm` to include the default catalog. """ bind = op.get_bind() session = db.Session(bind=bind) + for database in session.query(Database).all(): db_engine_spec = database.db_engine_spec if ( @@ -126,83 +139,204 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None: ) or not db_engine_spec.supports_catalog: continue - catalog = database.get_default_catalog() - if catalog is None: - continue + # For some databases, fetching the default catalog requires a connection to the + # analytical DB. If we can't connect to the analytical DB during the migration + # we should stop it, since we need the default catalog in order to update + # existing models. + if default_catalog := database.get_default_catalog(): + upgrade_database_catalogs(database, default_catalog, session) + + session.flush() - perm = security_manager.get_catalog_perm( + +def upgrade_database_catalogs( + database: Database, + default_catalog: str, + session: Session, +) -> None: + """ + Upgrade a given database to support the default catalog. + """ + catalog_perm = security_manager.get_catalog_perm( + database.database_name, + default_catalog, + ) + pvms: dict[str, tuple[str, ...]] = {catalog_perm: ("catalog_access",)} + + # rename existing schema permissions to include the catalog, and also find any new + # schemas + new_schema_pvms = upgrade_schema_perms(database, default_catalog, session) + pvms.update(new_schema_pvms) + + # update existing models that have a `catalog` column so it points to the default + # catalog + models = [ + (Query, "database_id"), + (SavedQuery, "db_id"), + (TabState, "database_id"), + (TableSchema, "database_id"), + ] + for model, column in models: + for instance in session.query(model).filter( + getattr(model, column) == database.id + ): + instance.catalog = default_catalog + + # update `schema_perm` and `catalog_perm` for tables and charts + for table in session.query(SqlaTable).filter_by( + database_id=database.id, + catalog=None, + ): + schema_perm = security_manager.get_schema_perm( database.database_name, - catalog, + default_catalog, + table.schema, ) - add_pvms(session, {perm: ("catalog_access",)}) - - upgrade_schema_perms(database, catalog, session) - - # update existing models - models = [ - (Query, "database_id"), - (SavedQuery, "db_id"), - (TabState, "database_id"), - (TableSchema, "database_id"), - (SqlaTable, "database_id"), - ] - for model, column in models: - for instance in session.query(model).filter( - getattr(model, column) == database.id - ): - instance.catalog = catalog - - for table in session.query(SqlaTable).filter_by(database_id=database.id): - schema_perm = security_manager.get_schema_perm( - database.database_name, - catalog, - table.schema, - ) - table.schema_perm = schema_perm - for chart in session.query(Slice).filter_by( - datasource_id=table.id, - datasource_type="table", - ): - chart.schema_perm = schema_perm - session.commit() + table.catalog = default_catalog + table.catalog_perm = catalog_perm + table.schema_perm = schema_perm + + for chart in session.query(Slice).filter_by( + datasource_id=table.id, + datasource_type="table", + ): + chart.catalog_perm = catalog_perm + chart.schema_perm = schema_perm + # add any new catalogs discovered and their schemas + new_catalog_pvms = add_non_default_catalogs(database, default_catalog, session) + pvms.update(new_catalog_pvms) -def upgrade_schema_perms(database: Database, catalog: str, session: Session) -> None: + # add default catalog permission and permissions for any new found schemas, and also + # permissions for new catalogs and their schemas + add_pvms(session, pvms) + + +def add_non_default_catalogs( + database: Database, + default_catalog: str, + session: Session, +) -> dict[str, tuple[str]]: """ - Rename existing schema permissions to include the catalog. + Add permissions for additional catalogs and their schemas. """ - ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) try: - schemas = database.get_all_schema_names( - catalog=catalog, - cache=False, - ssh_tunnel=ssh_tunnel, - ) - except Exception: # pylint: disable=broad-except - schemas = get_schemas(database.database_name) + catalogs = { + catalog + for catalog in database.get_all_catalog_names() + if catalog != default_catalog + } + except GenericDBException: + # If we can't connect to the analytical DB to fetch the catalogs we should just + # return. The catalog and schema permissions can be created later when the DB is + # edited. + return {} + + pvms = {} + for catalog in catalogs: + perm = security_manager.get_catalog_perm(database.database_name, catalog) + pvms[perm] = ("catalog_access",) + + new_schema_pvms = create_schema_perms(database, catalog, session) + pvms.update(new_schema_pvms) + + return pvms + + +def upgrade_schema_perms( + database: Database, + default_catalog: str, + session: Session, +) -> dict[str, tuple[str]]: + """ + Rename existing schema permissions to include the catalog. + + Schema permissions are stored (and processed) as strings, in the form: + + [database_name].[schema_name] + + When catalogs are first introduced for a DB engine spec we need to rename any + existing permissions to the form: + [database_name].[default_catalog_name].[schema_name] + + """ + schemas = get_known_schemas(database.database_name, session) + + perms = {} for schema in schemas: - perm = security_manager.get_schema_perm( + current_perm = security_manager.get_schema_perm( database.database_name, None, schema, ) - existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none() - if existing_pvm: - existing_pvm.name = security_manager.get_schema_perm( - database.database_name, - catalog, - schema, - ) + new_perm = security_manager.get_schema_perm( + database.database_name, + default_catalog, + schema, + ) + + if ( + existing_pvm := session.query(ViewMenu) + .filter_by(name=current_perm) + .one_or_none() + ): + existing_pvm.name = new_perm + else: + # new schema discovered, need to create a new permission + perms[new_perm] = ("schema_access",) + + return perms + + +def create_schema_perms( + database: Database, + catalog: str, + session: Session, +) -> dict[str, tuple[str]]: + """ + Create schema permissions for a given catalog. + """ + try: + schemas = database.get_all_schema_names(catalog=catalog) + except GenericDBException: + # If we can't connect to the analytical DB to fetch schemas in this catalog we + # should just return. The schema permissions can be created when the DB is + # edited. + return {} + + return { + security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ): ("schema_access",) + for schema in schemas + } def downgrade_catalog_perms(engines: set[str] | None = None) -> None: """ Reverse the process of `upgrade_catalog_perms`. + + This should: + + - Delete all `catalog_access` permissions. + - Rename `schema_access` permissions in the default catalog to omit it. + - Delete `schema_access` permissions for schemas not in the default catalog. + + Also, for models in the default catalog we should: + + - Populate the `catalog` field with `None`. + - Update `schema_perm` to omit the default catalog. + - Populate the `catalog_perm` field with `None`. + + WARNING: models (datasets and charts) not in the default catalog are deleted! """ bind = op.get_bind() session = db.Session(bind=bind) + for database in session.query(Database).all(): db_engine_spec = database.db_engine_spec if ( @@ -210,70 +344,155 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None: ) or not db_engine_spec.supports_catalog: continue - catalog = database.get_default_catalog() - if catalog is None: - continue + if default_catalog := database.get_default_catalog(): + downgrade_database_catalogs(database, default_catalog, session) + + session.flush() + + +def downgrade_database_catalogs( + database: Database, + default_catalog: str, + session: Session, +) -> None: + # remove all catalog permissions associated with the DB + prefix = f"[{database.database_name}].%" + for pvm in ( + session.query(PermissionView) + .join(Permission, PermissionView.permission_id == Permission.id) + .join(ViewMenu, PermissionView.view_menu_id == ViewMenu.id) + .filter( + Permission.name == "catalog_access", + ViewMenu.name.like(prefix), + ) + .all() + ): + session.delete(pvm) + session.delete(pvm.view_menu) + + # rename existing schemas permissions to omit the catalog, and remove schema + # permissions associated with other catalogs + downgrade_schema_perms(database, default_catalog, session) + + # update existing models + models = [ + (Query, "database_id"), + (SavedQuery, "db_id"), + (TabState, "database_id"), + (TableSchema, "database_id"), + ] + for model, column in models: + for instance in session.query(model).filter( + getattr(model, column) == database.id, + model.catalog == default_catalog, # type: ignore + ): + instance.catalog = None + + # update `schema_perm` for tables and charts + for table in session.query(SqlaTable).filter_by( + database_id=database.id, + catalog=default_catalog, + ): + schema_perm = security_manager.get_schema_perm( + database.database_name, + None, + table.schema, + ) - downgrade_schema_perms(database, catalog, session) - - # update existing models - models = [ - (Query, "database_id"), - (SavedQuery, "db_id"), - (TabState, "database_id"), - (TableSchema, "database_id"), - (SqlaTable, "database_id"), - ] - for model, column in models: - for instance in session.query(model).filter( - getattr(model, column) == database.id - ): - instance.catalog = None - - for table in session.query(SqlaTable).filter_by(database_id=database.id): - schema_perm = security_manager.get_schema_perm( - database.database_name, - None, - table.schema, + table.catalog = None + table.catalog_perm = None + table.schema_perm = schema_perm + + for chart in session.query(Slice).filter_by( + datasource_id=table.id, + datasource_type="table", + ): + chart.catalog_perm = None + chart.schema_perm = schema_perm + + # delete models referencing non-default catalogs + for model, column in models: + for instance in session.query(model).filter( + getattr(model, column) == database.id, + model.catalog != default_catalog, # type: ignore + ): + session.delete(instance) + + # delete datasets and any associated permissions + for table in session.query(SqlaTable).filter( + SqlaTable.database_id == database.id, + SqlaTable.catalog != default_catalog, + ): + for chart in session.query(Slice).filter( + Slice.datasource_id == table.id, + Slice.datasource_type == "table", + ): + session.delete(chart) + + session.delete(table) + pvm = ( + session.query(PermissionView) + .join(Permission, PermissionView.permission_id == Permission.id) + .join(ViewMenu, PermissionView.view_menu_id == ViewMenu.id) + .filter( + Permission.name == "datasource_access", + ViewMenu.name == table.perm, ) - table.schema_perm = schema_perm - for chart in session.query(Slice).filter_by( - datasource_id=table.id, - datasource_type="table", - ): - chart.schema_perm = schema_perm + .one() + ) + session.delete(pvm) + session.delete(pvm.view_menu) - session.commit() + session.flush() -def downgrade_schema_perms(database: Database, catalog: str, session: Session) -> None: +def downgrade_schema_perms( + database: Database, + default_catalog: str, + session: Session, +) -> None: """ - Rename existing schema permissions to omit the catalog. + Rename default catalog schema permissions and delete other schema permissions. """ - ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) - try: - schemas = database.get_all_schema_names( - catalog=catalog, - cache=False, - ssh_tunnel=ssh_tunnel, + prefix = f"[{database.database_name}].%" + pvms = ( + session.query(PermissionView) + .join(Permission, PermissionView.permission_id == Permission.id) + .join(ViewMenu, PermissionView.view_menu_id == ViewMenu.id) + .filter( + Permission.name == "schema_access", + ViewMenu.name.like(prefix), ) - except Exception: # pylint: disable=broad-except - schemas = get_schemas(database.database_name) + .all() + ) + + pvms_to_delete = [] + pvms_to_rename = [] + for pvm in pvms: + parts = pvm.view_menu.name[1:-1].split("].[") + if len(parts) != 3: + logger.warning( + "Invalid schema permission: %s. Please fix manually", + pvm.view_menu.name, + ) + continue - for schema in schemas: - perm = security_manager.get_schema_perm( - database.database_name, - catalog, - schema, - ) - existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none() - if existing_pvm: - new_perm = security_manager.get_schema_perm( - database.database_name, + database_name, catalog, schema = parts + + if catalog == default_catalog: + new_name = security_manager.get_schema_perm( + database_name, None, schema, ) - if pvm := session.query(ViewMenu).filter_by(name=new_perm).one_or_none(): - session.delete(pvm) - session.flush() - existing_pvm.name = new_perm + pvms_to_rename.append((pvm, new_name)) + else: + # non-default catalog, delete schema perm + pvms_to_delete.append(pvm) + + for pvm in pvms_to_delete: + session.delete(pvm) + session.delete(pvm.view_menu) + + for pvm, new_name in pvms_to_rename: + pvm.view_menu.name = new_name diff --git a/tests/unit_tests/migrations/shared/catalogs_test.py b/tests/unit_tests/migrations/shared/catalogs_test.py index 78ef5222171d7..56d202eaca61c 100644 --- a/tests/unit_tests/migrations/shared/catalogs_test.py +++ b/tests/unit_tests/migrations/shared/catalogs_test.py @@ -22,17 +22,18 @@ downgrade_catalog_perms, upgrade_catalog_perms, ) -from superset.migrations.shared.security_converge import ViewMenu +from superset.migrations.shared.security_converge import ( + Permission, + PermissionView, + ViewMenu, +) def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None: """ Test the `upgrade_catalog_perms` function. - The function is called when catalogs are introduced into a new DB engine spec. When - that happens, we need to update the `catalog` attribute so it points to the default - catalog, instead of being `NULL`. We also need to update `schema_perms` to include - the default catalog. + The function is called when catalogs are introduced into a new DB engine spec. """ from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database @@ -51,6 +52,11 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None: "get_all_schema_names", return_value=["public", "information_schema"], ) + mocker.patch.object( + Database, + "get_all_catalog_names", + return_value=["db", "other_catalog"], + ) database = Database( database_name="my_db", @@ -61,6 +67,7 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None: database=database, catalog=None, schema="public", + catalog_perm=None, schema_perm="[my_db].[public]", ) session.add(dataset) @@ -70,6 +77,8 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None: slice_name="my_chart", datasource_type="table", datasource_id=dataset.id, + catalog_perm=None, + schema_perm="[my_db].[public]", ) query = Query( client_id="foo", @@ -102,15 +111,43 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None: assert saved_query.catalog is None assert tab_state.catalog is None assert table_schema.catalog is None + assert dataset.catalog_perm is None assert dataset.schema_perm == "[my_db].[public]" + assert chart.catalog_perm is None assert chart.schema_perm == "[my_db].[public]" - assert session.query(ViewMenu.name).all() == [ - ("[my_db].(id:1)",), - ("[my_db].[my_table](id:1)",), - ("[my_db].[public]",), + assert ( + session.query(ViewMenu.name, Permission.name) + .join(PermissionView, ViewMenu.id == PermissionView.view_menu_id) + .join(Permission, PermissionView.permission_id == Permission.id) + .all() + ) == [ + ("[my_db].(id:1)", "database_access"), + ("[my_db].[my_table](id:1)", "datasource_access"), + ("[my_db].[public]", "schema_access"), ] upgrade_catalog_perms() + session.commit() + + # add dataset/chart in new catalog + new_dataset = SqlaTable( + table_name="my_table", + database=database, + catalog="other_catalog", + schema="public", + schema_perm="[my_db].[other_catalog].[public]", + catalog_perm="[my_db].[other_catalog]", + ) + session.add(new_dataset) + session.commit() + + new_chart = Slice( + slice_name="my_chart", + datasource_type="table", + datasource_id=new_dataset.id, + ) + session.add(new_chart) + session.commit() # after migration assert dataset.catalog == "db" @@ -118,16 +155,29 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None: assert saved_query.catalog == "db" assert tab_state.catalog == "db" assert table_schema.catalog == "db" + assert dataset.catalog_perm == "[my_db].[db]" assert dataset.schema_perm == "[my_db].[db].[public]" + assert chart.catalog_perm == "[my_db].[db]" assert chart.schema_perm == "[my_db].[db].[public]" - assert session.query(ViewMenu.name).all() == [ - ("[my_db].(id:1)",), - ("[my_db].[my_table](id:1)",), - ("[my_db].[db].[public]",), - ("[my_db].[db]",), + assert ( + session.query(ViewMenu.name, Permission.name) + .join(PermissionView, ViewMenu.id == PermissionView.view_menu_id) + .join(Permission, PermissionView.permission_id == Permission.id) + .all() + ) == [ + ("[my_db].(id:1)", "database_access"), + ("[my_db].[my_table](id:1)", "datasource_access"), + ("[my_db].[db].[public]", "schema_access"), + ("[my_db].[db]", "catalog_access"), + ("[my_db].[other_catalog]", "catalog_access"), + ("[my_db].[other_catalog].[public]", "schema_access"), + ("[my_db].[other_catalog].[information_schema]", "schema_access"), + ("[my_db].[my_table](id:2)", "datasource_access"), ] + # do a downgrade downgrade_catalog_perms() + session.commit() # revert assert dataset.catalog is None @@ -135,15 +185,25 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None: assert saved_query.catalog is None assert tab_state.catalog is None assert table_schema.catalog is None + assert dataset.catalog_perm is None assert dataset.schema_perm == "[my_db].[public]" + assert chart.catalog_perm is None assert chart.schema_perm == "[my_db].[public]" - assert session.query(ViewMenu.name).all() == [ - ("[my_db].(id:1)",), - ("[my_db].[my_table](id:1)",), - ("[my_db].[public]",), - ("[my_db].[db]",), + assert ( + session.query(ViewMenu.name, Permission.name) + .join(PermissionView, ViewMenu.id == PermissionView.view_menu_id) + .join(Permission, PermissionView.permission_id == Permission.id) + .all() + ) == [ + ("[my_db].(id:1)", "database_access"), + ("[my_db].[my_table](id:1)", "datasource_access"), + ("[my_db].[public]", "schema_access"), ] + # make sure new dataset/chart were deleted + assert session.query(SqlaTable).all() == [dataset] + assert session.query(Slice).all() == [chart] + def test_upgrade_catalog_perms_graceful( mocker: MockerFixture, @@ -236,6 +296,7 @@ def test_upgrade_catalog_perms_graceful( ] upgrade_catalog_perms() + session.commit() # after migration assert dataset.catalog == "db" @@ -253,6 +314,7 @@ def test_upgrade_catalog_perms_graceful( ] downgrade_catalog_perms() + session.commit() # revert assert dataset.catalog is None @@ -266,5 +328,4 @@ def test_upgrade_catalog_perms_graceful( ("[my_db].(id:1)",), ("[my_db].[my_table](id:1)",), ("[my_db].[public]",), - ("[my_db].[db]",), ]