
Add database filter for the assessment workflow #989

Merged · 15 commits · Mar 7, 2024
3 changes: 3 additions & 0 deletions src/databricks/labs/ucx/config.py
@@ -36,6 +36,9 @@ class WorkspaceConfig: # pylint: disable=too-many-instance-attributes
     override_clusters: dict[str, str] | None = None
     policy_id: str | None = None

+    # Optional list of databases to limit the assessment to; if not set, all databases are listed.
+    include_databases: list[str] | None = None

     def replace_inventory_variable(self, text: str) -> str:
         return text.replace("$inventory", f"hive_metastore.{self.inventory_database}")

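A quick sketch of how the new field might be set, assuming the other `WorkspaceConfig` fields keep their defaults (the inventory and database names below are placeholders, not values from this PR):

```python
from databricks.labs.ucx.config import WorkspaceConfig

# Restrict the assessment to two databases ("ucx", "sales" and "finance"
# are hypothetical placeholders).
config = WorkspaceConfig(
    inventory_database="ucx",
    include_databases=["sales", "finance"],
)

# Leaving the field unset keeps the old behaviour: every database in the
# Hive metastore is scanned.
full_scan = WorkspaceConfig(inventory_database="ucx")
```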
13 changes: 9 additions & 4 deletions src/databricks/labs/ucx/hive_metastore/grants.py
@@ -144,13 +144,14 @@


 class GrantsCrawler(CrawlerBase[Grant]):
-    def __init__(self, tc: TablesCrawler, udf: UdfsCrawler):
+    def __init__(self, tc: TablesCrawler, udf: UdfsCrawler, include_databases: list[str] | None = None):
         assert tc._backend == udf._backend
         assert tc._catalog == udf._catalog
         assert tc._schema == udf._schema
         super().__init__(tc._backend, tc._catalog, tc._schema, "grants", Grant)
         self._tc = tc
         self._udf = udf
+        self._include_databases = include_databases

     def snapshot(self) -> Iterable[Grant]:
         return self._snapshot(partial(self._try_load), partial(self._crawl))
@@ -189,9 +190,13 @@
         # Scanning ANY FILE and ANONYMOUS FUNCTION grants
         tasks.append(partial(self.grants, catalog=catalog, any_file=True))
         tasks.append(partial(self.grants, catalog=catalog, anonymous_function=True))
-        # scan all databases, even empty ones
-        for row in self._fetch(f"SHOW DATABASES FROM {escape_sql_identifier(catalog)}"):
-            tasks.append(partial(self.grants, catalog=catalog, database=row.databaseName))
+        if not self._include_databases:
+            # scan all databases, even empty ones
+            for row in self._fetch(f"SHOW DATABASES FROM {escape_sql_identifier(catalog)}"):
+                tasks.append(partial(self.grants, catalog=catalog, database=row.databaseName))
+        else:
+            for database in self._include_databases:
+                tasks.append(partial(self.grants, catalog=catalog, database=database))

Codecov / codecov/patch warning: added line src/databricks/labs/ucx/hive_metastore/grants.py#L199 was not covered by tests.
         for table in self._tc.snapshot():
             fn = partial(self.grants, catalog=catalog, database=table.database)
             # views are recognized as tables
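The control flow of the new branch can be summarized in a standalone sketch; `fetch` stands in for the crawler's `_fetch` and is an assumption, not part of this PR:

```python
# A simplified sketch of the branching above, not the actual crawler code:
# when a filter is configured, the SHOW DATABASES round-trip is skipped and
# the scan list is built directly from the user-supplied databases.
def databases_to_scan(fetch, include_databases=None, catalog="hive_metastore"):
    if not include_databases:
        # scan all databases, even empty ones
        return [row.databaseName for row in fetch(f"SHOW DATABASES FROM {catalog}")]
    return list(include_databases)
```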
4 changes: 2 additions & 2 deletions src/databricks/labs/ucx/hive_metastore/table_size.py
@@ -18,7 +18,7 @@ class TableSize:


 class TableSizeCrawler(CrawlerBase):
-    def __init__(self, backend: SqlBackend, schema):
+    def __init__(self, backend: SqlBackend, schema, include_databases: list[str] | None = None):
         """
         Initializes a TableSizeCrawler instance.

@@ -31,7 +31,7 @@ def __init__(self, backend: SqlBackend, schema):

         self._backend = backend
         super().__init__(backend, "hive_metastore", schema, "table_size", TableSize)
-        self._tables_crawler = TablesCrawler(backend, schema)
+        self._tables_crawler = TablesCrawler(backend, schema, include_databases)
         self._spark = SparkSession.builder.getOrCreate()

     def _crawl(self) -> Iterable[TableSize]:
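`TableSizeCrawler` does not filter anything itself; it forwards the list to its inner `TablesCrawler`. A hedged usage sketch, where `backend` is assumed to be an existing `SqlBackend` and the names are placeholders:

```python
# Hypothetical usage: sizes are computed only for tables that the filtered
# inner TablesCrawler reports, so the filter applies transitively.
crawler = TableSizeCrawler(backend, "ucx", include_databases=["sales"])
table_sizes = crawler.snapshot()
```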
14 changes: 8 additions & 6 deletions src/databricks/labs/ucx/hive_metastore/tables.py
@@ -1,7 +1,7 @@
 import logging
 import re
 import typing
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -10,7 +10,6 @@

 from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
 from databricks.labs.ucx.framework.utils import escape_sql_identifier
-from databricks.labs.ucx.mixins.sql import Row

logger = logging.getLogger(__name__)

@@ -151,7 +150,7 @@ class MigrationCount:


 class TablesCrawler(CrawlerBase):
-    def __init__(self, backend: SqlBackend, schema):
+    def __init__(self, backend: SqlBackend, schema, include_databases: list[str] | None = None):
         """
         Initializes a TablesCrawler instance.

@@ -160,9 +159,12 @@ def __init__(self, backend: SqlBackend, schema):
             schema: The schema name for the inventory persistence.
         """
         super().__init__(backend, "hive_metastore", schema, "tables", Table)
+        self._include_database = include_databases

-    def _all_databases(self) -> Iterator[Row]:
-        yield from self._fetch("SHOW DATABASES")
+    def _all_databases(self) -> list[str]:
+        if not self._include_database:
+            return [row[0] for row in self._fetch("SHOW DATABASES")]
+        return self._include_database

     def snapshot(self) -> list[Table]:
         """
@@ -208,7 +210,7 @@ def _crawl(self) -> Iterable[Table]:
"""
tasks = []
catalog = "hive_metastore"
for (database,) in self._all_databases():
for database in self._all_databases():
logger.debug(f"[{catalog}.{database}] listing tables")
for _, table, _is_tmp in self._fetch(
f"SHOW TABLES FROM {escape_sql_identifier(catalog)}.{escape_sql_identifier(database)}"
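Note the contract change: `_all_databases()` now returns plain names (`list[str]`) instead of yielding `Row` objects, which is why the call site drops the single-element tuple unpacking. A minimal sketch of the new behavior, with `fetch` standing in for the crawler's `_fetch`:

```python
# Sketch of the new _all_databases() contract: plain strings out, and a
# configured filter short-circuits the SHOW DATABASES query entirely.
def all_databases(fetch, include_databases=None):
    if not include_databases:
        return [row[0] for row in fetch("SHOW DATABASES")]
    return include_databases

assert all_databases(lambda _: [("sales",), ("finance",)]) == ["sales", "finance"]
assert all_databases(lambda _: [], include_databases=["sales"]) == ["sales"]
```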
20 changes: 13 additions & 7 deletions src/databricks/labs/ucx/hive_metastore/tables.scala
@@ -13,7 +13,6 @@ import org.apache.spark.sql.functions.{col,lower,upper}
 // must follow the same structure as databricks.labs.ucx.hive_metastore.tables.Table
 case class TableDetails(catalog: String, database: String, name: String, object_type: String,
                         table_format: String, location: String, view_text: String, upgraded_to: String, storage_properties: String)
-
 // recording error log in the database
 case class TableError(catalog: String, database: String, name: String, error: String)

@@ -72,20 +71,27 @@ def metadataForAllTables(databases: Seq[String], queue: ConcurrentLinkedQueue[Ta
 }).toList.toDF
 }

-def getInventoryDatabase(): String={
+def getConfig(): java.util.Map[String, Any] = {
   dbutils.widgets.text("config", "./config.yml")
   val configFile = dbutils.widgets.get("config")
   val fs = FileSystem.get(new java.net.URI("file:/Workspace"), sc.hadoopConfiguration)
   val file = fs.open(new Path(configFile))
   val configContents = org.apache.commons.io.IOUtils.toString(file, java.nio.charset.StandardCharsets.UTF_8)
-  val configObj = new Yaml().load(configContents).asInstanceOf[java.util.Map[String, Any]]
-  val inventoryDatabase = configObj.get("inventory_database").toString()
-  return inventoryDatabase
+  return new Yaml().load(configContents).asInstanceOf[java.util.Map[String, Any]]
 }

+def getInventoryDatabase(configObj: java.util.Map[String, Any]): String = {
+  return configObj.get("inventory_database").toString()
+}
+
+def getDatabasesToFilter(configObj: java.util.Map[String, Any]): List[String] = {
+  return JavaConverters.asScalaBuffer(configObj.get("include_databases").asInstanceOf[java.util.ArrayList[String]]).toList
+}

-val inventoryDatabase = getInventoryDatabase()
-var df = metadataForAllTables(spark.sharedState.externalCatalog.listDatabases().filter(_ != s"$inventoryDatabase"), failures)
+val config = getConfig()
+val inventoryDatabase = getInventoryDatabase(config)
+val databasesToFilter = getDatabasesToFilter(config)
+var df = metadataForAllTables(spark.sharedState.externalCatalog.listDatabases().filter(databasesToFilter.contains(_)), failures)
 var columnsToMapLower = Array("catalog","database","name","upgraded_to","storage_properties")
 columnsToMapLower.map(column => {
   df = df.withColumn(column, lower(col(column)))
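For readers who want the same config plumbing outside the notebook, a rough Python equivalent of `getConfig`/`getDatabasesToFilter` above; the `None` fallback is our assumption (the Scala version would need similar guarding when `include_databases` is absent from config.yml):

```python
import yaml

# Rough, hedged equivalent of the Scala getConfig/getDatabasesToFilter pair.
def load_databases_to_filter(path="./config.yml"):
    with open(path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    include = config.get("include_databases")
    # Assumption: a missing key means "scan everything", signalled by None.
    return list(include) if include else None
```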
14 changes: 8 additions & 6 deletions src/databricks/labs/ucx/hive_metastore/udfs.py
@@ -1,5 +1,5 @@
 import logging
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable
 from dataclasses import dataclass
 from functools import partial

@@ -8,7 +8,6 @@

 from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
 from databricks.labs.ucx.framework.utils import escape_sql_identifier
-from databricks.labs.ucx.mixins.sql import Row

logger = logging.getLogger(__name__)

@@ -32,7 +31,7 @@ def key(self) -> str:


 class UdfsCrawler(CrawlerBase):
-    def __init__(self, backend: SqlBackend, schema):
+    def __init__(self, backend: SqlBackend, schema: str, include_databases: list[str] | None = None):
         """
         Initializes a UdfsCrawler instance.

@@ -41,9 +40,12 @@ def __init__(self, backend: SqlBackend, schema):
             schema: The schema name for the inventory persistence.
         """
         super().__init__(backend, "hive_metastore", schema, "udfs", Udf)
+        self._include_database = include_databases

-    def _all_databases(self) -> Iterator[Row]:
-        yield from self._fetch("SHOW DATABASES")
+    def _all_databases(self) -> list[str]:
+        if not self._include_database:
+            return [row[0] for row in self._fetch("SHOW DATABASES")]
+        return self._include_database

     def snapshot(self) -> list[Udf]:
         """
@@ -66,7 +68,7 @@ def _crawl(self) -> Iterable[Udf]:
         # need to set the current catalog, otherwise "SHOW USER FUNCTIONS FROM" raises the error:
         # "target schema <database> is not in the current catalog"
         self._exec(f"USE CATALOG {escape_sql_identifier(catalog)};")
-        for (database,) in self._all_databases():
+        for database in self._all_databases():
             try:
                 logger.debug(f"[{catalog}.{database}] listing udfs")
                 for (udf,) in self._fetch(
12 changes: 11 additions & 1 deletion src/databricks/labs/ucx/install.py
@@ -192,7 +192,7 @@ def run(self):
         )
         workspace_installation.run()

-    def configure(self) -> WorkspaceConfig:
+    def configure(self) -> WorkspaceConfig:  # pylint: disable=too-many-locals
         try:
             return self._installation.load(WorkspaceConfig)
         except NotFound as err:
@@ -226,6 +226,15 @@ def warehouse_type(_):

         configure_groups = ConfigureGroups(self._prompts)
         configure_groups.run()
+        selected_databases = self._prompts.question(
+            "Comma-separated list of databases to migrate. If not specified, we'll use all "
+            "databases in hive_metastore",
+            default="<ALL>",
+        )
+        include_databases = None
+        if selected_databases != "<ALL>":
+            include_databases = [x.strip() for x in selected_databases.split(",")]
+
         log_level = self._prompts.question("Log level", default="INFO").upper()
         num_threads = int(self._prompts.question("Number of threads", default="8", valid_number=True))

@@ -258,6 +267,7 @@ def warehouse_type(_):
             instance_profile=instance_profile,
             spark_conf=spark_conf_dict,
             policy_id=policy_id,
+            include_databases=include_databases,
         )
         self._installation.save(config)
         ws_file_url = self._installation.workspace_link(config.__file__)
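The sentinel handling introduced here is easy to verify in isolation; a small sketch (the function name is ours, not the installer's):

```python
# "<ALL>" is the prompt default and means "no filter"; any other answer is
# split on commas and whitespace-trimmed, mirroring the installer code above.
def parse_selected_databases(answer: str) -> list[str] | None:
    if answer == "<ALL>":
        return None
    return [x.strip() for x in answer.split(",")]

assert parse_selected_databases("<ALL>") is None
assert parse_selected_databases("sales, finance") == ["sales", "finance"]
```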
4 changes: 2 additions & 2 deletions src/databricks/labs/ucx/runtime.py
@@ -52,8 +52,8 @@

     Note: This job runs on a separate cluster (named `tacl`) as it requires the proper configuration to have the Table
     ACLs enabled and available for retrieval."""
-    tables = TablesCrawler(sql_backend, cfg.inventory_database)
-    udfs = UdfsCrawler(sql_backend, cfg.inventory_database)
+    tables = TablesCrawler(sql_backend, cfg.inventory_database, cfg.include_databases)
+    udfs = UdfsCrawler(sql_backend, cfg.inventory_database, cfg.include_databases)

Codecov / codecov/patch warning: added lines src/databricks/labs/ucx/runtime.py#L55-L56 were not covered by tests.
     grants = GrantsCrawler(tables, udfs)
     grants.snapshot()

5 changes: 4 additions & 1 deletion tests/integration/hive_metastore/test_tables.py
@@ -31,7 +31,10 @@ def test_describe_all_tables_in_databases(ws, sql_backend, inventory_schema, mak
f"view={view.full_name}"
)

tables = TablesCrawler(sql_backend, inventory_schema)
schema_c = make_schema(catalog_name="hive_metastore")
make_table(schema_name=schema_c.name)

tables = TablesCrawler(sql_backend, inventory_schema, [schema_a.name, schema_b.name])
william-conti marked this conversation as resolved.
Show resolved Hide resolved

all_tables = {}
for table in tables.snapshot():
4 changes: 3 additions & 1 deletion tests/integration/hive_metastore/test_udfs.py
@@ -12,11 +12,13 @@
 @retried(on=[NotFound], timeout=timedelta(minutes=2))
 def test_describe_all_udfs_in_databases(ws, sql_backend, inventory_schema, make_schema, make_udf):
     schema_a = make_schema(catalog_name="hive_metastore")
+    schema_b = make_schema(catalog_name="hive_metastore")
     make_schema(catalog_name="hive_metastore")
     udf_a = make_udf(schema_name=schema_a.name)
     udf_b = make_udf(schema_name=schema_a.name)
+    make_udf(schema_name=schema_b.name)

-    udfs_crawler = UdfsCrawler(sql_backend, inventory_schema)
+    udfs_crawler = UdfsCrawler(sql_backend, inventory_schema, [schema_a.name, schema_b.name])
     actual_grants = udfs_crawler.snapshot()

     unique_udf_grants = {
10 changes: 10 additions & 0 deletions tests/unit/hive_metastore/test_tables.py
@@ -220,3 +220,13 @@ def test_is_supported_for_sync(table, supported):
 )
 def test_table_what(table, what):
     assert table.what == what
+
+
+def test_tables_crawler_should_filter_by_database():
+    rows = {
+        "SHOW TABLES FROM hive_metastore.database": [("", "table1", ""), ("", "table2", "")],
+    }
+    backend = MockBackend(rows=rows)
+    tables_crawler = TablesCrawler(backend, "default", ["database"])
+    results = tables_crawler.snapshot()
+    assert len(results) == 2
12 changes: 12 additions & 0 deletions tests/unit/hive_metastore/test_udfs.py
@@ -33,3 +33,15 @@ def test_udfs_returning_error_when_describing():
     udf_crawler = UdfsCrawler(backend, "default")
     results = udf_crawler.snapshot()
     assert len(results) == 0
+
+
+def test_udfs_crawler_should_filter_by_database():
+    rows = {
+        "SHOW USER FUNCTIONS FROM hive_metastore.database": [
+            make_row(("hive_metastore.database.function1",), ["function"]),
+        ],
+    }
+    backend = MockBackend(rows=rows)
+    udf_crawler = UdfsCrawler(backend, "default", ["database"])
+    results = udf_crawler.snapshot()
+    assert len(results) == 1