Skip to content

Commit

Permalink
Refactor _get_catalog_schema_grants
Browse files Browse the repository at this point in the history
  • Loading branch information
JCZuurmond committed Apr 29, 2024
1 parent ef1ea46 commit f3fe2a2
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions src/databricks/labs/ucx/hive_metastore/catalog_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,25 @@ def _update_principal_acl(self):
logger.debug(f"Migrating acls on {grant.this_type_and_key()} using SQL query: {acl_migrate_sql}")
self._backend.execute(acl_migrate_sql)

def _get_catalog_schema_grants(self):
catalog_grants: set[Grant] = set()
new_grants = []
src_trg_schema_mapping = self._get_database_source_target_mapping()
grants = self._principal_grants.get_interactive_cluster_grants()
# filter on grants to only get database level grants
database_grants = [grant for grant in grants if grant.table is None and grant.view is None]
for db_grant in database_grants:
for target_catalog, target_schema in src_trg_schema_mapping[db_grant.database]:
new_grants.append(dataclasses.replace(db_grant, catalog=target_catalog, database=target_schema))
for grant in new_grants:
catalog_grants.add(dataclasses.replace(grant, database=None))
new_grants.extend(catalog_grants)
return new_grants
def _get_catalog_schema_grants(self) -> list[Grant]:
database_grants = []
source_to_target_schema_mapping = self._get_database_source_target_mapping()
for grant in self._principal_grants.get_interactive_cluster_grants():
if grant.table is not None or grant.view is not None: # Filter for database/schema grants
continue
if grant.database is None:
continue

for target_catalog, target_schema in source_to_target_schema_mapping[grant.database]:
database_grants.append(dataclasses.replace(grant, catalog=target_catalog, database=target_schema))

catalog_grants, seen_catalogs = [], set()
for grant in database_grants:
if grant.catalog not in seen_catalogs:
catalog_grants.append(dataclasses.replace(grant, database=None))
seen_catalogs.add(grant.catalog)

return catalog_grants + database_grants

def _get_database_source_target_mapping(self) -> dict[str, set[tuple[str, str]]]:
"""Generate a dictionary of source database in hive_metastore and its
Expand Down

0 comments on commit f3fe2a2

Please sign in to comment.