This repository has been archived by the owner on May 17, 2024. It is now read-only.

Added --materialize-all-rows switch + tests #258

Merged: 3 commits, merged Oct 17, 2022
data_diff/__main__.py: 17 additions & 2 deletions
@@ -12,7 +12,7 @@
 from .utils import eval_name_template, remove_password_from_url, safezip, match_like
 from .diff_tables import Algorithm
 from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
-from .joindiff_tables import JoinDiffer
+from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
 from .table_segment import TableSegment
 from .databases.database_types import create_schema
 from .databases.connect import connect
@@ -144,7 +144,18 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
 @click.option(
     "--sample-exclusive-rows",
     is_flag=True,
-    help="Sample several rows that only appear in one of the tables, but not the other.",
+    help="Sample several rows that only appear in one of the tables, but not the other. (joindiff only)",
+)
+@click.option(
+    "--materialize-all-rows",
+    is_flag=True,
+    help="Materialize every row, even if they are the same, instead of just the differing rows. (joindiff only)",
+)
+@click.option(
+    "--table-write-limit",
+    default=TABLE_WRITE_LIMIT,
+    help=f"Maximum number of rows to write when creating materialized or sample tables, per thread. Default={TABLE_WRITE_LIMIT}",
+    metavar="COUNT",
 )
 @click.option(
     "-j",
@@ -214,6 +225,8 @@ def _main(
     where,
     assume_unique_key,
     sample_exclusive_rows,
+    materialize_all_rows,
+    table_write_limit,
     materialize,
     threads1=None,
     threads2=None,
@@ -303,6 +316,8 @@ def _main(
             max_threadpool_size=threads and threads * 2,
             validate_unique_key=not assume_unique_key,
             sample_exclusive_rows=sample_exclusive_rows,
+            materialize_all_rows=materialize_all_rows,
+            table_write_limit=table_write_limit,
             materialize_to_table=materialize and db1.parse_table_name(eval_name_template(materialize)),
         )
     else:
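As a rough illustration of how the two new options behave once click parses them, here is a standalone sketch (not the real data-diff command; the command name and echoed output are invented for the example):

import click

TABLE_WRITE_LIMIT = 1000  # same default value the PR exports from joindiff_tables

@click.command()
@click.option(
    "--materialize-all-rows",
    is_flag=True,
    help="Materialize every row, even identical ones, instead of just the differing rows.",
)
@click.option(
    "--table-write-limit",
    default=TABLE_WRITE_LIMIT,
    metavar="COUNT",
    help=f"Maximum number of rows to write per thread. Default={TABLE_WRITE_LIMIT}",
)
def demo(materialize_all_rows: bool, table_write_limit: int):
    # click maps the dashed option names to underscored parameters,
    # which is how the values reach _main() in the diff above.
    click.echo(f"materialize_all_rows={materialize_all_rows}")
    click.echo(f"table_write_limit={table_write_limit}")

if __name__ == "__main__":
    demo()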
data_diff/joindiff_tables.py: 16 additions & 11 deletions
@@ -29,7 +29,7 @@

 logger = logging.getLogger("joindiff_tables")

-WRITE_LIMIT = 1000
+TABLE_WRITE_LIMIT = 1000


 def merge_dicts(dicts):
@@ -115,13 +115,14 @@ class JoinDiffer(TableDiffer):
         Future versions will detect UNIQUE constraints in the schema.
         sample_exclusive_rows (bool): Enable/disable sampling of exclusive rows. Creates a temporary table.
         materialize_to_table (DbPath, optional): Path of new table to write diff results to. Disabled if not provided.
-        write_limit (int): Maximum number of rows to write when materializing, per thread.
+        table_write_limit (int): Maximum number of rows to write when materializing, per thread.
     """

     validate_unique_key: bool = True
     sample_exclusive_rows: bool = True
     materialize_to_table: DbPath = None
-    write_limit: int = WRITE_LIMIT
+    materialize_all_rows: bool = False
+    table_write_limit: int = TABLE_WRITE_LIMIT
     stats: dict = {}

     def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
@@ -165,15 +166,20 @@ def _diff_segments(
         )

         db = table1.database
-        diff_rows, a_cols, b_cols, is_diff_cols = self._create_outer_join(table1, table2)
+        diff_rows, a_cols, b_cols, is_diff_cols, all_rows = self._create_outer_join(table1, table2)

         with self._run_in_background(
             partial(self._collect_stats, 1, table1),
             partial(self._collect_stats, 2, table2),
             partial(self._test_null_keys, table1, table2),
             partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols),
             partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols),
-            partial(self._materialize_diff, db, diff_rows, segment_index=segment_index)
+            partial(
+                self._materialize_diff,
+                db,
+                all_rows if self.materialize_all_rows else diff_rows,
+                segment_index=segment_index,
+            )
             if self.materialize_to_table
             else None,
         ):
@@ -263,10 +269,9 @@ def _create_outer_join(self, table1, table2):
         a_cols = {f"table1_{c}": NormalizeAsString(a[c]) for c in cols1}
         b_cols = {f"table2_{c}": NormalizeAsString(b[c]) for c in cols2}

-        diff_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}).where(
-            or_(this[c] == 1 for c in is_diff_cols)
-        )
-        return diff_rows, a_cols, b_cols, is_diff_cols
+        all_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols})
+        diff_rows = all_rows.where(or_(this[c] == 1 for c in is_diff_cols))
+        return diff_rows, a_cols, b_cols, is_diff_cols, all_rows

     def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols):
         logger.info("Counting differences per column")
@@ -293,7 +298,7 @@ def exclusive_rows(expr):
             c = Compiler(db)
             name = c.new_unique_table_name("temp_table")
             exclusive_rows = table(name, schema=expr.source_table.schema)
-            yield create_temp_table(c, exclusive_rows, expr.limit(self.write_limit))
+            yield create_temp_table(c, exclusive_rows, expr.limit(self.table_write_limit))

             count = yield exclusive_rows.count()
             self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0]
@@ -309,5 +314,5 @@ def exclusive_rows(expr):
     def _materialize_diff(self, db, diff_rows, segment_index=None):
         assert self.materialize_to_table

-        append_to_table(db, self.materialize_to_table, diff_rows.limit(self.write_limit))
+        append_to_table(db, self.materialize_to_table, diff_rows.limit(self.table_write_limit))
         logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table))
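Taken together, the joindiff changes let a JoinDiffer materialize the entire outer join instead of only the differing rows. A minimal usage sketch, assuming db1 is an already-connected database and segment1/segment2 are TableSegment objects built the same way _main builds them (the output table name here is made up):

from data_diff.joindiff_tables import JoinDiffer, TABLE_WRITE_LIMIT

differ = JoinDiffer(
    validate_unique_key=True,
    sample_exclusive_rows=False,
    materialize_all_rows=True,            # write every joined row, not only rows that differ
    table_write_limit=TABLE_WRITE_LIMIT,  # per-thread cap on written rows (1000 by default)
    materialize_to_table=db1.parse_table_name("diff_results"),  # hypothetical output table
)

# diff_tables() still yields ('-', row) / ('+', row) pairs; materialization is a side effect.
for sign, row in differ.diff_tables(segment1, segment2):
    print(sign, row)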
tests/test_joindiff.py: 10 additions & 1 deletion
@@ -133,9 +133,18 @@ def test_diff_small_tables(self):

         t = TablePath(materialize_path)
         rows = self.connection.query(t.select(), List[tuple])
-        self.connection.query(t.drop())
         # is_xa, is_xb, is_diff1, is_diff2, row1, row2
         assert rows == [(1, 0, 1, 1) + expected_row + (None, None)], rows
+        self.connection.query(t.drop())
+
+        # Test materialize all rows
+        mdiffer = mdiffer.replace(materialize_all_rows=True)
+        diff = list(mdiffer.diff_tables(self.table, self.table2))
+        self.assertEqual(expected, diff)
+        rows = self.connection.query(t.select(), List[tuple])
+        assert len(rows) == 2, len(rows)
+        self.connection.query(t.drop())
+

     def test_diff_table_above_bisection_threshold(self):
         time = "2022-01-01 00:00:00"
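The updated test reuses the same differ through mdiffer.replace(materialize_all_rows=True), which returns a copy with one field overridden rather than mutating the original; with that flag set, the materialized table receives every joined row, which is why the asserted row count rises from 1 to 2. The same immutable-update pattern with only the standard library looks roughly like this (an analogy, not data-diff's actual implementation):

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class DifferConfig:
    materialize_all_rows: bool = False
    table_write_limit: int = 1000

base = DifferConfig()
all_rows = replace(base, materialize_all_rows=True)  # new object, one field changed
assert base.materialize_all_rows is False
assert all_rows.materialize_all_rows is True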