Add database filter for the assessment workflow (#989)
## Changes
Added a database filter so the assessment can crawl only a specified subset of databases
in the Hive Metastore instead of all of them.

### Linked issues
Resolves #937

### Functionality 

- [x] modified assessment

### Tests

- [X] manually tested
- [X] added unit tests
- [X] added integration tests
- [X] verified on staging environment (screenshot attached)
william-conti authored and dmoore247 committed Mar 23, 2024
1 parent 03d5657 commit cbe8c09
Showing 16 changed files with 180 additions and 32 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -372,6 +372,7 @@ access the configuration file from the command line. Here's the description of c
* `spark_conf`: An optional dictionary of Spark configuration properties.
* `override_clusters`: An optional dictionary mapping job cluster names to existing cluster IDs.
* `policy_id`: An optional string representing the ID of the cluster policy.
* `include_databases`: An optional list of strings representing the names of databases to include for migration.
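For illustration, these optional fields can be set together when the configuration is built programmatically; a minimal sketch, with placeholder values throughout (none of them are part of this change):

```python
from databricks.labs.ucx.config import WorkspaceConfig

# Sketch only: every value below is a placeholder chosen for illustration.
config = WorkspaceConfig(
    inventory_database="ucx",                        # schema that holds the inventory tables
    spark_conf={"spark.databricks.delta.preview.enabled": "true"},
    policy_id="0123456789ABCDEF",                    # hypothetical cluster policy ID
    include_databases=["sales", "finance"],          # assess only these Hive Metastore databases
)
```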

[[back to top](#databricks-labs-ucx)]

3 changes: 3 additions & 0 deletions src/databricks/labs/ucx/config.py
@@ -41,6 +41,9 @@ class WorkspaceConfig: # pylint: disable=too-many-instance-attributes
# Flag to see if terraform has been used for deploying certain entities
is_terraform_used: bool = False

# Whether the assessment should be limited to a specific list of databases; if not specified, all databases are listed.
include_databases: list[str] | None = None

def replace_inventory_variable(self, text: str) -> str:
return text.replace("$inventory", f"hive_metastore.{self.inventory_database}")

13 changes: 9 additions & 4 deletions src/databricks/labs/ucx/hive_metastore/grants.py
@@ -144,13 +144,14 @@ def uc_grant_sql(self):


class GrantsCrawler(CrawlerBase[Grant]):
def __init__(self, tc: TablesCrawler, udf: UdfsCrawler):
def __init__(self, tc: TablesCrawler, udf: UdfsCrawler, include_databases: list[str] | None = None):
assert tc._backend == udf._backend
assert tc._catalog == udf._catalog
assert tc._schema == udf._schema
super().__init__(tc._backend, tc._catalog, tc._schema, "grants", Grant)
self._tc = tc
self._udf = udf
self._include_databases = include_databases

def snapshot(self) -> Iterable[Grant]:
return self._snapshot(partial(self._try_load), partial(self._crawl))
@@ -189,9 +190,13 @@ def _crawl(self) -> Iterable[Grant]:
# Scanning ANY FILE and ANONYMOUS FUNCTION grants
tasks.append(partial(self.grants, catalog=catalog, any_file=True))
tasks.append(partial(self.grants, catalog=catalog, anonymous_function=True))
# scan all databases, even empty ones
for row in self._fetch(f"SHOW DATABASES FROM {escape_sql_identifier(catalog)}"):
tasks.append(partial(self.grants, catalog=catalog, database=row.databaseName))
if not self._include_databases:
# scan all databases, even empty ones
for row in self._fetch(f"SHOW DATABASES FROM {escape_sql_identifier(catalog)}"):
tasks.append(partial(self.grants, catalog=catalog, database=row.databaseName))
else:
for database in self._include_databases:
tasks.append(partial(self.grants, catalog=catalog, database=database))
for table in self._tc.snapshot():
fn = partial(self.grants, catalog=catalog, database=table.database)
# views are recognized as tables
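One design note on this hunk: `GrantsCrawler` uses its own `include_databases` list only for the database-level grant scans; table- and view-level grants are still derived from the `TablesCrawler` snapshot (and UDF grants from `UdfsCrawler`). To filter consistently, the same list should therefore be passed to all three crawlers, as in this sketch (the backend and schema names are placeholders):

```python
from databricks.labs.ucx.hive_metastore.grants import GrantsCrawler
from databricks.labs.ucx.hive_metastore.tables import TablesCrawler
from databricks.labs.ucx.hive_metastore.udfs import UdfsCrawler


def crawl_filtered_grants(sql_backend, inventory_schema: str, include: list[str]):
    # Pass the same include list everywhere so database-, table- and UDF-level
    # grants all cover exactly the selected databases.
    tables = TablesCrawler(sql_backend, inventory_schema, include)
    udfs = UdfsCrawler(sql_backend, inventory_schema, include)
    grants = GrantsCrawler(tables, udfs, include)
    return grants.snapshot()
```

This mirrors the wiring in `runtime.py` below.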
4 changes: 2 additions & 2 deletions src/databricks/labs/ucx/hive_metastore/table_size.py
@@ -18,7 +18,7 @@ class TableSize:


class TableSizeCrawler(CrawlerBase):
def __init__(self, backend: SqlBackend, schema):
def __init__(self, backend: SqlBackend, schema, include_databases: list[str] | None = None):
"""
Initializes a TablesSizeCrawler instance.
@@ -31,7 +31,7 @@ def __init__(self, backend: SqlBackend, schema):

self._backend = backend
super().__init__(backend, "hive_metastore", schema, "table_size", TableSize)
self._tables_crawler = TablesCrawler(backend, schema)
self._tables_crawler = TablesCrawler(backend, schema, include_databases)
self._spark = SparkSession.builder.getOrCreate()

def _crawl(self) -> Iterable[TableSize]:
14 changes: 8 additions & 6 deletions src/databricks/labs/ucx/hive_metastore/tables.py
@@ -1,7 +1,7 @@
import logging
import re
import typing
from collections.abc import Iterable, Iterator
from collections.abc import Iterable
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
@@ -10,7 +10,6 @@

from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
from databricks.labs.ucx.framework.utils import escape_sql_identifier
from databricks.labs.ucx.mixins.sql import Row

logger = logging.getLogger(__name__)

@@ -158,7 +157,7 @@ class MigrationCount:


class TablesCrawler(CrawlerBase):
def __init__(self, backend: SqlBackend, schema):
def __init__(self, backend: SqlBackend, schema, include_databases: list[str] | None = None):
"""
Initializes a TablesCrawler instance.
@@ -167,9 +166,12 @@ def __init__(self, backend: SqlBackend, schema):
schema: The schema name for the inventory persistence.
"""
super().__init__(backend, "hive_metastore", schema, "tables", Table)
self._include_database = include_databases

def _all_databases(self) -> Iterator[Row]:
yield from self._fetch("SHOW DATABASES")
def _all_databases(self) -> list[str]:
if not self._include_database:
return [row[0] for row in self._fetch("SHOW DATABASES")]
return self._include_database

def snapshot(self) -> list[Table]:
"""
Expand Down Expand Up @@ -215,7 +217,7 @@ def _crawl(self) -> Iterable[Table]:
"""
tasks = []
catalog = "hive_metastore"
for (database,) in self._all_databases():
for database in self._all_databases():
logger.debug(f"[{catalog}.{database}] listing tables")
for _, table, _is_tmp in self._fetch(
f"SHOW TABLES FROM {escape_sql_identifier(catalog)}.{escape_sql_identifier(database)}"
25 changes: 18 additions & 7 deletions src/databricks/labs/ucx/hive_metastore/tables.scala
@@ -13,7 +13,6 @@ import org.apache.spark.sql.functions.{col,lower,upper}
// must follow the same structure as databricks.labs.ucx.hive_metastore.tables.Table
case class TableDetails(catalog: String, database: String, name: String, object_type: String,
table_format: String, location: String, view_text: String, upgraded_to: String, storage_properties: String)

// recording error log in the database
case class TableError(catalog: String, database: String, name: String, error: String)

@@ -72,20 +71,32 @@ def metadataForAllTables(databases: Seq[String], queue: ConcurrentLinkedQueue[Ta
}).toList.toDF
}

def getInventoryDatabase(): String={
def getConfig(): java.util.Map[String, Any] = {
dbutils.widgets.text("config", "./config.yml")
val configFile = dbutils.widgets.get("config")
val fs = FileSystem.get(new java.net.URI("file:/Workspace"), sc.hadoopConfiguration)
val file = fs.open(new Path(configFile))
val configContents = org.apache.commons.io.IOUtils.toString(file, java.nio.charset.StandardCharsets.UTF_8)
val configObj = new Yaml().load(configContents).asInstanceOf[java.util.Map[String, Any]]
val inventoryDatabase = configObj.get("inventory_database").toString()
return inventoryDatabase
return new Yaml().load(configContents).asInstanceOf[java.util.Map[String, Any]]
}

def getInventoryDatabase(configObj:java.util.Map[String, Any]): String ={
return configObj.get("inventory_database").toString()
}

def getIncludeDatabases(configObj:java.util.Map[String, Any], inventoryDatabase:String): Seq[String] ={
val includeDatabases = JavaConverters.asScalaBuffer(configObj.getOrDefault("include_databases",new java.util.ArrayList[String]()).asInstanceOf[java.util.ArrayList[String]]).toList

if (includeDatabases.isEmpty) {
return spark.sharedState.externalCatalog.listDatabases().filter(_ != s"$inventoryDatabase")
}
return spark.sharedState.externalCatalog.listDatabases().filter(includeDatabases.contains(_))
}

val inventoryDatabase = getInventoryDatabase()
var df = metadataForAllTables(spark.sharedState.externalCatalog.listDatabases().filter(_ != s"$inventoryDatabase"), failures)
val config = getConfig()
val inventoryDatabase = getInventoryDatabase(config)
val includeDatabases = getIncludeDatabases(config, inventoryDatabase)
var df = metadataForAllTables(includeDatabases, failures)
var columnsToMapLower = Array("catalog","database","name","upgraded_to","storage_properties")
columnsToMapLower.map(column => {
df = df.withColumn(column, lower(col(column)))
14 changes: 8 additions & 6 deletions src/databricks/labs/ucx/hive_metastore/udfs.py
@@ -1,5 +1,5 @@
import logging
from collections.abc import Iterable, Iterator
from collections.abc import Iterable
from dataclasses import dataclass
from functools import partial

@@ -8,7 +8,6 @@

from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
from databricks.labs.ucx.framework.utils import escape_sql_identifier
from databricks.labs.ucx.mixins.sql import Row

logger = logging.getLogger(__name__)

@@ -32,7 +31,7 @@ def key(self) -> str:


class UdfsCrawler(CrawlerBase):
def __init__(self, backend: SqlBackend, schema):
def __init__(self, backend: SqlBackend, schema: str, include_databases: list[str] | None = None):
"""
Initializes a UdfsCrawler instance.
@@ -41,9 +40,12 @@ def __init__(self, backend: SqlBackend, schema):
schema: The schema name for the inventory persistence.
"""
super().__init__(backend, "hive_metastore", schema, "udfs", Udf)
self._include_database = include_databases

def _all_databases(self) -> Iterator[Row]:
yield from self._fetch("SHOW DATABASES")
def _all_databases(self) -> list[str]:
if not self._include_database:
return [row[0] for row in self._fetch("SHOW DATABASES")]
return self._include_database

def snapshot(self) -> list[Udf]:
"""
@@ -66,7 +68,7 @@ def _crawl(self) -> Iterable[Udf]:
# need to set the current catalog otherwise "SHOW USER FUNCTIONS FROM" is raising error:
# "target schema <database> is not in the current catalog"
self._exec(f"USE CATALOG {escape_sql_identifier(catalog)};")
for (database,) in self._all_databases():
for database in self._all_databases():
for task in self._collect_tasks(catalog, database):
tasks.append(task)
catalog_tables, errors = Threads.gather(f"listing udfs in {catalog}", tasks)
12 changes: 12 additions & 0 deletions src/databricks/labs/ucx/install.py
@@ -239,6 +239,7 @@ def _configure_new_installation(self) -> WorkspaceConfig:
warehouse_id = self._configure_warehouse()
configure_groups = ConfigureGroups(self._prompts)
configure_groups.run()

log_level = self._prompts.question("Log level", default="INFO").upper()
num_threads = int(self._prompts.question("Number of threads", default="8", valid_number=True))

@@ -275,13 +276,24 @@ def _configure_new_installation(self) -> WorkspaceConfig:
spark_conf=spark_conf_dict,
policy_id=policy_id,
is_terraform_used=is_terraform_used,
include_databases=self._select_databases(),
)
self._installation.save(config)
ws_file_url = self._installation.workspace_link(config.__file__)
if self._prompts.confirm(f"Open config file in the browser and continue installing? {ws_file_url}"):
webbrowser.open(ws_file_url)
return config

def _select_databases(self):
selected_databases = self._prompts.question(
"Comma-separated list of databases to migrate. If not specified, we'll use all "
"databases in hive_metastore",
default="<ALL>",
)
if selected_databases != "<ALL>":
return [x.strip() for x in selected_databases.split(",")]
return None
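The parsing here is simple but worth spelling out: anything other than the literal `<ALL>` sentinel is split on commas and whitespace-trimmed, while the sentinel maps to `None` (no filter). A standalone sketch of that behavior with made-up answers:

```python
def parse_selected_databases(answer: str) -> list[str] | None:
    # Mirrors the prompt handling added in this commit: "<ALL>" means no filter.
    if answer == "<ALL>":
        return None
    return [name.strip() for name in answer.split(",")]


assert parse_selected_databases("<ALL>") is None
assert parse_selected_databases("sales, finance ,hr") == ["sales", "finance", "hr"]
```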

def _configure_warehouse(self):
def warehouse_type(_):
return _.warehouse_type.value if not _.enable_serverless_compute else "SERVERLESS"
6 changes: 3 additions & 3 deletions src/databricks/labs/ucx/runtime.py
@@ -52,9 +52,9 @@ def crawl_grants(cfg: WorkspaceConfig, _: WorkspaceClient, sql_backend: SqlBacke
Note: This job runs on a separate cluster (named `tacl`) as it requires the proper configuration to have the Table
ACLs enabled and available for retrieval."""
tables = TablesCrawler(sql_backend, cfg.inventory_database)
udfs = UdfsCrawler(sql_backend, cfg.inventory_database)
grants = GrantsCrawler(tables, udfs)
tables = TablesCrawler(sql_backend, cfg.inventory_database, cfg.include_databases)
udfs = UdfsCrawler(sql_backend, cfg.inventory_database, cfg.include_databases)
grants = GrantsCrawler(tables, udfs, cfg.include_databases)
grants.snapshot()


7 changes: 5 additions & 2 deletions tests/integration/hive_metastore/test_tables.py
@@ -31,13 +31,16 @@ def test_describe_all_tables_in_databases(ws, sql_backend, inventory_schema, mak
f"view={view.full_name}"
)

tables = TablesCrawler(sql_backend, inventory_schema)
schema_c = make_schema(catalog_name="hive_metastore")
make_table(schema_name=schema_c.name)

tables = TablesCrawler(sql_backend, inventory_schema, [schema_a.name, schema_b.name])

all_tables = {}
for table in tables.snapshot():
all_tables[table.key] = table

assert len(all_tables) >= 5
assert len(all_tables) == 5
assert all_tables[non_delta.full_name].table_format == "JSON"
assert all_tables[non_delta.full_name].what == What.DB_DATASET
assert all_tables[managed_table.full_name].object_type == "MANAGED"
4 changes: 3 additions & 1 deletion tests/integration/hive_metastore/test_udfs.py
@@ -12,11 +12,13 @@
@retried(on=[NotFound], timeout=timedelta(minutes=2))
def test_describe_all_udfs_in_databases(ws, sql_backend, inventory_schema, make_schema, make_udf):
schema_a = make_schema(catalog_name="hive_metastore")
schema_b = make_schema(catalog_name="hive_metastore")
make_schema(catalog_name="hive_metastore")
udf_a = make_udf(schema_name=schema_a.name)
udf_b = make_udf(schema_name=schema_a.name)
make_udf(schema_name=schema_b.name)

udfs_crawler = UdfsCrawler(sql_backend, inventory_schema)
udfs_crawler = UdfsCrawler(sql_backend, inventory_schema, [schema_a.name, schema_b.name])
actual_grants = udfs_crawler.snapshot()

unique_udf_grants = {
34 changes: 34 additions & 0 deletions tests/unit/hive_metastore/test_grants.py
@@ -429,3 +429,37 @@ def test_udf_grants_returning_error_when_describing():
anonymous_function=False,
)
]


def test_crawler_should_filter_databases():
sql_backend = MockBackend(
rows={
"SHOW TABLES FROM hive_metastore.database_one": [
("database_one", "table_one", "true"),
("database_one", "table_two", "true"),
],
"SELECT * FROM hive_metastore.schema.tables": [
make_row(("foo", "bar", "test_table", "type", "DELTA", "/foo/bar/test", None), SELECT_COLS),
make_row(("foo", "bar", "test_view", "type", "VIEW", None, "SELECT * FROM table"), SELECT_COLS),
make_row(("foo", None, None, "type", "CATALOG", None, None), SELECT_COLS),
],
"DESCRIBE TABLE EXTENDED hive_metastore.database_one.*": [
make_row(("Catalog", "foo", "ignored"), DESCRIBE_COLS),
make_row(("Type", "TABLE", "ignored"), DESCRIBE_COLS),
make_row(("Provider", "", "ignored"), DESCRIBE_COLS),
make_row(("Location", "/foo/bar/test", "ignored"), DESCRIBE_COLS),
make_row(("View Text", "SELECT * FROM table", "ignored"), DESCRIBE_COLS),
],
"SHOW GRANTS ON .*": [
make_row(("princ1", "SELECT", "TABLE", "ignored"), SHOW_COLS),
make_row(("princ1", "SELECT", "VIEW", "ignored"), SHOW_COLS),
make_row(("princ1", "USE", "CATALOG$", "ignored"), SHOW_COLS),
],
}
)
table = TablesCrawler(sql_backend, "schema", include_databases=["database_one"])
udf = UdfsCrawler(sql_backend, "schema", include_databases=["database_one"])
crawler = GrantsCrawler(table, udf, include_databases=["database_one"])
grants = crawler.snapshot()
assert len(grants) == 3
assert 'SHOW TABLES FROM hive_metastore.database_one' in sql_backend.queries
17 changes: 17 additions & 0 deletions tests/unit/hive_metastore/test_tables.py
@@ -220,3 +220,20 @@ def test_is_supported_for_sync(table, supported):
)
def test_table_what(table, what):
assert table.what == what


def test_tables_crawler_should_filter_by_database():
rows = {
"SHOW TABLES FROM hive_metastore.database": [("", "table1", ""), ("", "table2", "")],
"SHOW TABLES FROM hive_metastore.database_2": [("", "table1", "")],
}
backend = MockBackend(rows=rows)
tables_crawler = TablesCrawler(backend, "default", ["database"])
results = tables_crawler.snapshot()
assert len(results) == 2
assert backend.queries == [
'SELECT * FROM hive_metastore.default.tables',
'SHOW TABLES FROM hive_metastore.database',
'DESCRIBE TABLE EXTENDED hive_metastore.database.table1',
'DESCRIBE TABLE EXTENDED hive_metastore.database.table2',
]
12 changes: 12 additions & 0 deletions tests/unit/hive_metastore/test_udfs.py
@@ -33,3 +33,15 @@ def test_udfs_returning_error_when_describing():
udf_crawler = UdfsCrawler(backend, "default")
results = udf_crawler.snapshot()
assert len(results) == 0


def test_udfs_crawler_should_filter_by_database():
rows = {
"SHOW USER FUNCTIONS FROM hive_metastore.database": [
make_row(("hive_metastore.database.function1",), ["function"]),
],
}
backend = MockBackend(rows=rows)
udf_crawler = UdfsCrawler(backend, "default", ["database"])
results = udf_crawler.snapshot()
assert len(results) == 1