diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 5ca5e15b..06b1cd60 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -12,7 +12,7 @@ from .utils import eval_name_template, remove_password_from_url, safezip, match_like from .diff_tables import Algorithm from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR -from .joindiff_tables import JoinDiffer +from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer from .table_segment import TableSegment from .databases.database_types import create_schema from .databases.connect import connect @@ -144,7 +144,18 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - @click.option( "--sample-exclusive-rows", is_flag=True, - help="Sample several rows that only appear in one of the tables, but not the other.", + help="Sample several rows that only appear in one of the tables, but not the other. (joindiff only)", +) +@click.option( + "--materialize-all-rows", + is_flag=True, + help="Materialize every row, even if they are the same, instead of just the differing rows. (joindiff only)", +) +@click.option( + "--table-write-limit", + default=TABLE_WRITE_LIMIT, + help=f"Maximum number of rows to write when creating materialized or sample tables, per thread. Default={TABLE_WRITE_LIMIT}", + metavar="COUNT", ) @click.option( "-j", @@ -214,6 +225,8 @@ def _main( where, assume_unique_key, sample_exclusive_rows, + materialize_all_rows, + table_write_limit, materialize, threads1=None, threads2=None, @@ -303,6 +316,8 @@ def _main( max_threadpool_size=threads and threads * 2, validate_unique_key=not assume_unique_key, sample_exclusive_rows=sample_exclusive_rows, + materialize_all_rows=materialize_all_rows, + table_write_limit=table_write_limit, materialize_to_table=materialize and db1.parse_table_name(eval_name_template(materialize)), ) else: diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 641f11d0..b630d66e 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -29,7 +29,7 @@ logger = logging.getLogger("joindiff_tables") -WRITE_LIMIT = 1000 +TABLE_WRITE_LIMIT = 1000 def merge_dicts(dicts): @@ -115,13 +115,14 @@ class JoinDiffer(TableDiffer): Future versions will detect UNIQUE constraints in the schema. sample_exclusive_rows (bool): Enable/disable sampling of exclusive rows. Creates a temporary table. materialize_to_table (DbPath, optional): Path of new table to write diff results to. Disabled if not provided. - write_limit (int): Maximum number of rows to write when materializing, per thread. + table_write_limit (int): Maximum number of rows to write when materializing, per thread. """ validate_unique_key: bool = True sample_exclusive_rows: bool = True materialize_to_table: DbPath = None - write_limit: int = WRITE_LIMIT + materialize_all_rows: bool = False + table_write_limit: int = TABLE_WRITE_LIMIT stats: dict = {} def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult: @@ -165,7 +166,7 @@ def _diff_segments( ) db = table1.database - diff_rows, a_cols, b_cols, is_diff_cols = self._create_outer_join(table1, table2) + diff_rows, a_cols, b_cols, is_diff_cols, all_rows = self._create_outer_join(table1, table2) with self._run_in_background( partial(self._collect_stats, 1, table1), @@ -173,7 +174,12 @@ def _diff_segments( partial(self._test_null_keys, table1, table2), partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols), partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols), - partial(self._materialize_diff, db, diff_rows, segment_index=segment_index) + partial( + self._materialize_diff, + db, + all_rows if self.materialize_all_rows else diff_rows, + segment_index=segment_index, + ) if self.materialize_to_table else None, ): @@ -263,10 +269,9 @@ def _create_outer_join(self, table1, table2): a_cols = {f"table1_{c}": NormalizeAsString(a[c]) for c in cols1} b_cols = {f"table2_{c}": NormalizeAsString(b[c]) for c in cols2} - diff_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}).where( - or_(this[c] == 1 for c in is_diff_cols) - ) - return diff_rows, a_cols, b_cols, is_diff_cols + all_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}) + diff_rows = all_rows.where(or_(this[c] == 1 for c in is_diff_cols)) + return diff_rows, a_cols, b_cols, is_diff_cols, all_rows def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): logger.info("Counting differences per column") @@ -293,7 +298,7 @@ def exclusive_rows(expr): c = Compiler(db) name = c.new_unique_table_name("temp_table") exclusive_rows = table(name, schema=expr.source_table.schema) - yield create_temp_table(c, exclusive_rows, expr.limit(self.write_limit)) + yield create_temp_table(c, exclusive_rows, expr.limit(self.table_write_limit)) count = yield exclusive_rows.count() self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0] @@ -309,5 +314,5 @@ def exclusive_rows(expr): def _materialize_diff(self, db, diff_rows, segment_index=None): assert self.materialize_to_table - append_to_table(db, self.materialize_to_table, diff_rows.limit(self.write_limit)) + append_to_table(db, self.materialize_to_table, diff_rows.limit(self.table_write_limit)) logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table)) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index d3db82e0..60279f6f 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -133,9 +133,18 @@ def test_diff_small_tables(self): t = TablePath(materialize_path) rows = self.connection.query(t.select(), List[tuple]) - self.connection.query(t.drop()) # is_xa, is_xb, is_diff1, is_diff2, row1, row2 assert rows == [(1, 0, 1, 1) + expected_row + (None, None)], rows + self.connection.query(t.drop()) + + # Test materialize all rows + mdiffer = mdiffer.replace(materialize_all_rows=True) + diff = list(mdiffer.diff_tables(self.table, self.table2)) + self.assertEqual(expected, diff) + rows = self.connection.query(t.select(), List[tuple]) + assert len(rows) == 2, len(rows) + self.connection.query(t.drop()) + def test_diff_table_above_bisection_threshold(self): time = "2022-01-01 00:00:00"