Add database filter for the assessment workflow #989

Merged · 15 commits · Mar 7, 2024
1 change: 1 addition & 0 deletions README.md
@@ -372,6 +372,7 @@ access the configuration file from the command line. Here's the description of c
* `spark_conf`: An optional dictionary of Spark configuration properties.
* `override_clusters`: An optional dictionary mapping job cluster names to existing cluster IDs.
* `policy_id`: An optional string representing the ID of the cluster policy.
* `include_databases`: An optional list of strings representing the names of databases to include for migration.

[[back to top](#databricks-labs-ucx)]

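For orientation, here is a hedged sketch of how the new entry could look once written to a config file and read back with PyYAML; the surrounding keys and values are illustrative, not the full set of fields the installer persists:

```python
# Illustrative only: a minimal config fragment containing the new key.
import yaml

CONFIG_FRAGMENT = """
inventory_database: ucx
log_level: INFO
include_databases:
  - sales
  - finance
"""

config = yaml.safe_load(CONFIG_FRAGMENT)
# When the key is absent or empty, crawlers fall back to scanning every database.
include_databases = config.get("include_databases") or None
print(include_databases)  # ['sales', 'finance']
```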
3 changes: 3 additions & 0 deletions src/databricks/labs/ucx/config.py
@@ -41,6 +41,9 @@ class WorkspaceConfig: # pylint: disable=too-many-instance-attributes
# Flag to see if terraform has been used for deploying certain entities
is_terraform_used: bool = False

# Whether the assessment should capture a specific list of databases; if not specified, all databases are listed.
include_databases: list[str] | None = None

def replace_inventory_variable(self, text: str) -> str:
return text.replace("$inventory", f"hive_metastore.{self.inventory_database}")

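The contract this field establishes is simple: `None` means "no filter", an explicit list restricts every crawler that receives it. A minimal sketch of that behaviour, using a stand-in dataclass rather than the real `WorkspaceConfig` (which carries many more fields):

```python
from dataclasses import dataclass


@dataclass
class WorkspaceConfigSketch:
    """Stand-in for WorkspaceConfig; only the fields relevant to the filter."""
    inventory_database: str
    include_databases: list[str] | None = None  # None => crawl all databases


def databases_to_scan(cfg: WorkspaceConfigSketch, all_databases: list[str]) -> list[str]:
    # Mirrors the pattern the crawlers use below: fall back to the full
    # listing when no filter is configured, otherwise take the list as-is.
    if not cfg.include_databases:
        return all_databases
    return cfg.include_databases


cfg = WorkspaceConfigSketch(inventory_database="ucx", include_databases=["sales"])
assert databases_to_scan(cfg, ["sales", "hr", "ucx"]) == ["sales"]
```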
13 changes: 9 additions & 4 deletions src/databricks/labs/ucx/hive_metastore/grants.py
@@ -144,13 +144,14 @@ def uc_grant_sql(self):


class GrantsCrawler(CrawlerBase[Grant]):
def __init__(self, tc: TablesCrawler, udf: UdfsCrawler):
def __init__(self, tc: TablesCrawler, udf: UdfsCrawler, include_databases: list[str] | None = None):
assert tc._backend == udf._backend
assert tc._catalog == udf._catalog
assert tc._schema == udf._schema
super().__init__(tc._backend, tc._catalog, tc._schema, "grants", Grant)
self._tc = tc
self._udf = udf
self._include_databases = include_databases

def snapshot(self) -> Iterable[Grant]:
return self._snapshot(partial(self._try_load), partial(self._crawl))
@@ -189,9 +190,13 @@ def _crawl(self) -> Iterable[Grant]:
# Scanning ANY FILE and ANONYMOUS FUNCTION grants
tasks.append(partial(self.grants, catalog=catalog, any_file=True))
tasks.append(partial(self.grants, catalog=catalog, anonymous_function=True))
# scan all databases, even empty ones
for row in self._fetch(f"SHOW DATABASES FROM {escape_sql_identifier(catalog)}"):
tasks.append(partial(self.grants, catalog=catalog, database=row.databaseName))
if not self._include_databases:
# scan all databases, even empty ones
for row in self._fetch(f"SHOW DATABASES FROM {escape_sql_identifier(catalog)}"):
tasks.append(partial(self.grants, catalog=catalog, database=row.databaseName))
else:
for database in self._include_databases:
tasks.append(partial(self.grants, catalog=catalog, database=database))
for table in self._tc.snapshot():
fn = partial(self.grants, catalog=catalog, database=table.database)
# views are recognized as tables
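The crawl fans out one `grants` call per database via `functools.partial`, and the filter only changes which database names feed that fan-out. A self-contained sketch of the branching, where `fetch_databases` and `collect_grants` are hypothetical stand-ins for the crawler's `_fetch` and `grants` methods:

```python
from functools import partial


def build_grant_tasks(include_databases, fetch_databases, collect_grants, catalog="hive_metastore"):
    """Build one deferred task per database, honouring the optional filter."""
    tasks = []
    if not include_databases:
        # No filter configured: scan every database, even empty ones.
        for database in fetch_databases(catalog):
            tasks.append(partial(collect_grants, catalog=catalog, database=database))
    else:
        # Filter configured: only the listed databases are scanned.
        for database in include_databases:
            tasks.append(partial(collect_grants, catalog=catalog, database=database))
    return tasks


# Usage sketch:
tasks = build_grant_tasks(
    include_databases=["sales"],
    fetch_databases=lambda catalog: ["sales", "hr", "tmp"],
    collect_grants=lambda catalog, database: f"grants for {catalog}.{database}",
)
print([task() for task in tasks])  # ['grants for hive_metastore.sales']
```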
4 changes: 2 additions & 2 deletions src/databricks/labs/ucx/hive_metastore/table_size.py
@@ -18,7 +18,7 @@ class TableSize:


class TableSizeCrawler(CrawlerBase):
def __init__(self, backend: SqlBackend, schema):
def __init__(self, backend: SqlBackend, schema, include_databases: list[str] | None = None):
"""
Initializes a TableSizeCrawler instance.

@@ -31,7 +31,7 @@ def __init__(self, backend: SqlBackend, schema):

self._backend = backend
super().__init__(backend, "hive_metastore", schema, "table_size", TableSize)
self._tables_crawler = TablesCrawler(backend, schema)
self._tables_crawler = TablesCrawler(backend, schema, include_databases)
self._spark = SparkSession.builder.getOrCreate()

def _crawl(self) -> Iterable[TableSize]:
14 changes: 8 additions & 6 deletions src/databricks/labs/ucx/hive_metastore/tables.py
@@ -1,7 +1,7 @@
import logging
import re
import typing
from collections.abc import Iterable, Iterator
from collections.abc import Iterable
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
@@ -10,7 +10,6 @@

from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
from databricks.labs.ucx.framework.utils import escape_sql_identifier
from databricks.labs.ucx.mixins.sql import Row

logger = logging.getLogger(__name__)

@@ -158,7 +157,7 @@ class MigrationCount:


class TablesCrawler(CrawlerBase):
def __init__(self, backend: SqlBackend, schema):
def __init__(self, backend: SqlBackend, schema, include_databases: list[str] | None = None):
"""
Initializes a TablesCrawler instance.

@@ -167,9 +166,12 @@ def __init__(self, backend: SqlBackend, schema):
schema: The schema name for the inventory persistence.
"""
super().__init__(backend, "hive_metastore", schema, "tables", Table)
self._include_database = include_databases

def _all_databases(self) -> Iterator[Row]:
yield from self._fetch("SHOW DATABASES")
def _all_databases(self) -> list[str]:
if not self._include_database:
return [row[0] for row in self._fetch("SHOW DATABASES")]
return self._include_database

def snapshot(self) -> list[Table]:
"""
@@ -215,7 +217,7 @@ def _crawl(self) -> Iterable[Table]:
"""
tasks = []
catalog = "hive_metastore"
for (database,) in self._all_databases():
for database in self._all_databases():
logger.debug(f"[{catalog}.{database}] listing tables")
for _, table, _is_tmp in self._fetch(
f"SHOW TABLES FROM {escape_sql_identifier(catalog)}.{escape_sql_identifier(database)}"
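The signature change from `Iterator[Row]` to `list[str]` is what forces the call-site change from `for (database,) in ...` to `for database in ...`. A tiny sketch of the two shapes, with single-column tuples standing in for `Row` objects:

```python
# Before: SHOW DATABASES rows behave like 1-tuples, so call sites unpacked them.
rows = [("sales",), ("hr",)]
for (database,) in rows:  # tuple unpacking required
    assert isinstance(database, str)


# After: _all_databases() returns plain strings, whether they come from
# SHOW DATABASES (row[0]) or straight from the include_databases filter.
def all_databases(include_databases, rows):
    if not include_databases:
        return [row[0] for row in rows]
    return include_databases


for database in all_databases(None, rows):  # no unpacking needed
    assert isinstance(database, str)
assert all_databases(["finance"], rows) == ["finance"]
```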
25 changes: 18 additions & 7 deletions src/databricks/labs/ucx/hive_metastore/tables.scala
@@ -13,7 +13,6 @@ import org.apache.spark.sql.functions.{col,lower,upper}
// must follow the same structure as databricks.labs.ucx.hive_metastore.tables.Table
case class TableDetails(catalog: String, database: String, name: String, object_type: String,
table_format: String, location: String, view_text: String, upgraded_to: String, storage_properties: String)

// recording error log in the database
case class TableError(catalog: String, database: String, name: String, error: String)

@@ -72,20 +71,32 @@ def metadataForAllTables(databases: Seq[String], queue: ConcurrentLinkedQueue[Ta
}).toList.toDF
}

def getInventoryDatabase(): String={
def getConfig(): java.util.Map[String, Any] = {
dbutils.widgets.text("config", "./config.yml")
val configFile = dbutils.widgets.get("config")
val fs = FileSystem.get(new java.net.URI("file:/Workspace"), sc.hadoopConfiguration)
val file = fs.open(new Path(configFile))
val configContents = org.apache.commons.io.IOUtils.toString(file, java.nio.charset.StandardCharsets.UTF_8)
val configObj = new Yaml().load(configContents).asInstanceOf[java.util.Map[String, Any]]
val inventoryDatabase = configObj.get("inventory_database").toString()
return inventoryDatabase
return new Yaml().load(configContents).asInstanceOf[java.util.Map[String, Any]]
}

def getInventoryDatabase(configObj:java.util.Map[String, Any]): String ={
return configObj.get("inventory_database").toString()
}

def getIncludeDatabases(configObj:java.util.Map[String, Any], inventoryDatabase:String): Seq[String] ={
val includeDatabases = JavaConverters.asScalaBuffer(configObj.getOrDefault("include_databases",new java.util.ArrayList[String]()).asInstanceOf[java.util.ArrayList[String]]).toList

if (includeDatabases.isEmpty) {
return spark.sharedState.externalCatalog.listDatabases().filter(_ != s"$inventoryDatabase")
}
return spark.sharedState.externalCatalog.listDatabases().filter(includeDatabases.contains(_))
}

val inventoryDatabase = getInventoryDatabase()
var df = metadataForAllTables(spark.sharedState.externalCatalog.listDatabases().filter(_ != s"$inventoryDatabase"), failures)
val config = getConfig()
val inventoryDatabase = getInventoryDatabase(config)
val includeDatabases = getIncludeDatabases(config, inventoryDatabase)
var df = metadataForAllTables(includeDatabases, failures)
var columnsToMapLower = Array("catalog","database","name","upgraded_to","storage_properties")
columnsToMapLower.map(column => {
df = df.withColumn(column, lower(col(column)))
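The Scala notebook applies the same idea, but with one difference: it intersects the filter with the catalog listing, and with no filter it drops the UCX inventory database from the scan. A Python mirror of that selection logic, where `all_databases` stands in for `spark.sharedState.externalCatalog.listDatabases()`:

```python
def select_databases(all_databases, inventory_database, include_databases):
    """Mirror of getIncludeDatabases: pick which databases the notebook scans."""
    if not include_databases:
        # No filter: everything except the UCX inventory database itself.
        return [db for db in all_databases if db != inventory_database]
    # Filter present: keep only databases that both exist and are requested.
    return [db for db in all_databases if db in include_databases]


catalog = ["sales", "hr", "ucx"]
assert select_databases(catalog, "ucx", None) == ["sales", "hr"]
assert select_databases(catalog, "ucx", ["hr", "missing"]) == ["hr"]
```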
14 changes: 8 additions & 6 deletions src/databricks/labs/ucx/hive_metastore/udfs.py
@@ -1,5 +1,5 @@
import logging
from collections.abc import Iterable, Iterator
from collections.abc import Iterable
from dataclasses import dataclass
from functools import partial

@@ -8,7 +8,6 @@

from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
from databricks.labs.ucx.framework.utils import escape_sql_identifier
from databricks.labs.ucx.mixins.sql import Row

logger = logging.getLogger(__name__)

@@ -32,7 +31,7 @@ def key(self) -> str:


class UdfsCrawler(CrawlerBase):
def __init__(self, backend: SqlBackend, schema):
def __init__(self, backend: SqlBackend, schema: str, include_databases: list[str] | None = None):
"""
Initializes a UdfsCrawler instance.

@@ -41,9 +40,12 @@ def __init__(self, backend: SqlBackend, schema):
schema: The schema name for the inventory persistence.
"""
super().__init__(backend, "hive_metastore", schema, "udfs", Udf)
self._include_database = include_databases

def _all_databases(self) -> Iterator[Row]:
yield from self._fetch("SHOW DATABASES")
def _all_databases(self) -> list[str]:
if not self._include_database:
return [row[0] for row in self._fetch("SHOW DATABASES")]
return self._include_database

def snapshot(self) -> list[Udf]:
"""
@@ -66,7 +68,7 @@ def _crawl(self) -> Iterable[Udf]:
# need to set the current catalog otherwise "SHOW USER FUNCTIONS FROM" is raising error:
# "target schema <database> is not in the current catalog"
self._exec(f"USE CATALOG {escape_sql_identifier(catalog)};")
for (database,) in self._all_databases():
for database in self._all_databases():
for task in self._collect_tasks(catalog, database):
tasks.append(task)
catalog_tables, errors = Threads.gather(f"listing udfs in {catalog}", tasks)
12 changes: 12 additions & 0 deletions src/databricks/labs/ucx/install.py
@@ -239,6 +239,7 @@ def _configure_new_installation(self) -> WorkspaceConfig:
warehouse_id = self._configure_warehouse()
configure_groups = ConfigureGroups(self._prompts)
configure_groups.run()

log_level = self._prompts.question("Log level", default="INFO").upper()
num_threads = int(self._prompts.question("Number of threads", default="8", valid_number=True))

@@ -275,13 +276,24 @@ def _configure_new_installation(self) -> WorkspaceConfig:
spark_conf=spark_conf_dict,
policy_id=policy_id,
is_terraform_used=is_terraform_used,
include_databases=self._select_databases(),
)
self._installation.save(config)
ws_file_url = self._installation.workspace_link(config.__file__)
if self._prompts.confirm(f"Open config file in the browser and continue installing? {ws_file_url}"):
webbrowser.open(ws_file_url)
return config

def _select_databases(self):
selected_databases = self._prompts.question(
"Comma-separated list of databases to migrate. If not specified, we'll use all "
"databases in hive_metastore",
default="<ALL>",
)
if selected_databases != "<ALL>":
return [x.strip() for x in selected_databases.split(",")]
return None

def _configure_warehouse(self):
def warehouse_type(_):
return _.warehouse_type.value if not _.enable_serverless_compute else "SERVERLESS"
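The installer prompt maps its answer onto the config value in a straightforward way: the default sentinel `<ALL>` becomes `None`, anything else is split on commas and stripped. A standalone sketch of that parsing, outside the installer class:

```python
def parse_selected_databases(answer: str) -> list[str] | None:
    """Turn the installer prompt's answer into an include_databases value."""
    if answer == "<ALL>":
        return None  # no filter: the assessment scans every database
    return [name.strip() for name in answer.split(",")]


assert parse_selected_databases("<ALL>") is None
assert parse_selected_databases("sales, finance ,hr") == ["sales", "finance", "hr"]
```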
6 changes: 3 additions & 3 deletions src/databricks/labs/ucx/runtime.py
@@ -52,9 +52,9 @@ def crawl_grants(cfg: WorkspaceConfig, _: WorkspaceClient, sql_backend: SqlBacke

Note: This job runs on a separate cluster (named `tacl`) as it requires the proper configuration to have the Table
ACLs enabled and available for retrieval."""
tables = TablesCrawler(sql_backend, cfg.inventory_database)
udfs = UdfsCrawler(sql_backend, cfg.inventory_database)
grants = GrantsCrawler(tables, udfs)
tables = TablesCrawler(sql_backend, cfg.inventory_database, cfg.include_databases)
udfs = UdfsCrawler(sql_backend, cfg.inventory_database, cfg.include_databases)
grants = GrantsCrawler(tables, udfs, cfg.include_databases)
grants.snapshot()


7 changes: 5 additions & 2 deletions tests/integration/hive_metastore/test_tables.py
@@ -31,13 +31,16 @@ def test_describe_all_tables_in_databases(ws, sql_backend, inventory_schema, mak
f"view={view.full_name}"
)

tables = TablesCrawler(sql_backend, inventory_schema)
schema_c = make_schema(catalog_name="hive_metastore")
make_table(schema_name=schema_c.name)

tables = TablesCrawler(sql_backend, inventory_schema, [schema_a.name, schema_b.name])

all_tables = {}
for table in tables.snapshot():
all_tables[table.key] = table

assert len(all_tables) >= 5
assert len(all_tables) == 5
assert all_tables[non_delta.full_name].table_format == "JSON"
assert all_tables[non_delta.full_name].what == What.DB_DATASET
assert all_tables[managed_table.full_name].object_type == "MANAGED"
4 changes: 3 additions & 1 deletion tests/integration/hive_metastore/test_udfs.py
@@ -12,11 +12,13 @@
@retried(on=[NotFound], timeout=timedelta(minutes=2))
def test_describe_all_udfs_in_databases(ws, sql_backend, inventory_schema, make_schema, make_udf):
schema_a = make_schema(catalog_name="hive_metastore")
schema_b = make_schema(catalog_name="hive_metastore")
make_schema(catalog_name="hive_metastore")
udf_a = make_udf(schema_name=schema_a.name)
udf_b = make_udf(schema_name=schema_a.name)
make_udf(schema_name=schema_b.name)

udfs_crawler = UdfsCrawler(sql_backend, inventory_schema)
udfs_crawler = UdfsCrawler(sql_backend, inventory_schema, [schema_a.name, schema_b.name])
actual_grants = udfs_crawler.snapshot()

unique_udf_grants = {
34 changes: 34 additions & 0 deletions tests/unit/hive_metastore/test_grants.py
@@ -429,3 +429,37 @@ def test_udf_grants_returning_error_when_describing():
anonymous_function=False,
)
]


def test_crawler_should_filter_databases():
sql_backend = MockBackend(
rows={
"SHOW TABLES FROM hive_metastore.database_one": [
("database_one", "table_one", "true"),
("database_one", "table_two", "true"),
],
"SELECT * FROM hive_metastore.schema.tables": [
make_row(("foo", "bar", "test_table", "type", "DELTA", "/foo/bar/test", None), SELECT_COLS),
make_row(("foo", "bar", "test_view", "type", "VIEW", None, "SELECT * FROM table"), SELECT_COLS),
make_row(("foo", None, None, "type", "CATALOG", None, None), SELECT_COLS),
],
"DESCRIBE TABLE EXTENDED hive_metastore.database_one.*": [
make_row(("Catalog", "foo", "ignored"), DESCRIBE_COLS),
make_row(("Type", "TABLE", "ignored"), DESCRIBE_COLS),
make_row(("Provider", "", "ignored"), DESCRIBE_COLS),
make_row(("Location", "/foo/bar/test", "ignored"), DESCRIBE_COLS),
make_row(("View Text", "SELECT * FROM table", "ignored"), DESCRIBE_COLS),
],
"SHOW GRANTS ON .*": [
make_row(("princ1", "SELECT", "TABLE", "ignored"), SHOW_COLS),
make_row(("princ1", "SELECT", "VIEW", "ignored"), SHOW_COLS),
make_row(("princ1", "USE", "CATALOG$", "ignored"), SHOW_COLS),
],
}
)
table = TablesCrawler(sql_backend, "schema", include_databases=["database_one"])
udf = UdfsCrawler(sql_backend, "schema", include_databases=["database_one"])
crawler = GrantsCrawler(table, udf, include_databases=["database_one"])
grants = crawler.snapshot()
assert len(grants) == 3
assert 'SHOW TABLES FROM hive_metastore.database_one' in sql_backend.queries
17 changes: 17 additions & 0 deletions tests/unit/hive_metastore/test_tables.py
@@ -220,3 +220,20 @@ def test_is_supported_for_sync(table, supported):
)
def test_table_what(table, what):
assert table.what == what


def test_tables_crawler_should_filter_by_database():
rows = {
"SHOW TABLES FROM hive_metastore.database": [("", "table1", ""), ("", "table2", "")],
"SHOW TABLES FROM hive_metastore.database_2": [("", "table1", "")],
}
backend = MockBackend(rows=rows)
tables_crawler = TablesCrawler(backend, "default", ["database"])
results = tables_crawler.snapshot()
assert len(results) == 2
assert backend.queries == [
'SELECT * FROM hive_metastore.default.tables',
'SHOW TABLES FROM hive_metastore.database',
'DESCRIBE TABLE EXTENDED hive_metastore.database.table1',
'DESCRIBE TABLE EXTENDED hive_metastore.database.table2',
]
12 changes: 12 additions & 0 deletions tests/unit/hive_metastore/test_udfs.py
@@ -33,3 +33,15 @@ def test_udfs_returning_error_when_describing():
udf_crawler = UdfsCrawler(backend, "default")
results = udf_crawler.snapshot()
assert len(results) == 0


def test_tables_crawler_should_filter_by_database():
rows = {
"SHOW USER FUNCTIONS FROM hive_metastore.database": [
make_row(("hive_metastore.database.function1",), ["function"]),
],
}
backend = MockBackend(rows=rows)
udf_crawler = UdfsCrawler(backend, "default", ["database"])
results = udf_crawler.snapshot()
assert len(results) == 1