Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
Merge pull request #258 from datafold/materialize_all_rows
Browse files (browse the repository at this point in the history)
Added --materialize-all-rows switch + tests
  • Loading branch information
erezsh authored Oct 17, 2022
2 parents 6a5ed77 + 3bf7e1c commit 52a7092
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 14 deletions.
19 changes: 17 additions & 2 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .utils import eval_name_template, remove_password_from_url, safezip, match_like
from .diff_tables import Algorithm
from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
from .joindiff_tables import JoinDiffer
from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
from .table_segment import TableSegment
from .databases.database_types import create_schema
from .databases.connect import connect
Expand Down Expand Up @@ -144,7 +144,18 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
@click.option(
"--sample-exclusive-rows",
is_flag=True,
help="Sample several rows that only appear in one of the tables, but not the other.",
help="Sample several rows that only appear in one of the tables, but not the other. (joindiff only)",
)
@click.option(
"--materialize-all-rows",
is_flag=True,
help="Materialize every row, even if they are the same, instead of just the differing rows. (joindiff only)",
)
@click.option(
"--table-write-limit",
default=TABLE_WRITE_LIMIT,
help=f"Maximum number of rows to write when creating materialized or sample tables, per thread. Default={TABLE_WRITE_LIMIT}",
metavar="COUNT",
)
@click.option(
"-j",
Expand Down Expand Up @@ -214,6 +225,8 @@ def _main(
where,
assume_unique_key,
sample_exclusive_rows,
materialize_all_rows,
table_write_limit,
materialize,
threads1=None,
threads2=None,
Expand Down Expand Up @@ -303,6 +316,8 @@ def _main(
max_threadpool_size=threads and threads * 2,
validate_unique_key=not assume_unique_key,
sample_exclusive_rows=sample_exclusive_rows,
materialize_all_rows=materialize_all_rows,
table_write_limit=table_write_limit,
materialize_to_table=materialize and db1.parse_table_name(eval_name_template(materialize)),
)
else:
Expand Down
27 changes: 16 additions & 11 deletions data_diff/joindiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

logger = logging.getLogger("joindiff_tables")

WRITE_LIMIT = 1000
TABLE_WRITE_LIMIT = 1000


def merge_dicts(dicts):
Expand Down Expand Up @@ -115,13 +115,14 @@ class JoinDiffer(TableDiffer):
Future versions will detect UNIQUE constraints in the schema.
sample_exclusive_rows (bool): Enable/disable sampling of exclusive rows. Creates a temporary table.
materialize_to_table (DbPath, optional): Path of new table to write diff results to. Disabled if not provided.
write_limit (int): Maximum number of rows to write when materializing, per thread.
table_write_limit (int): Maximum number of rows to write when materializing, per thread.
"""

validate_unique_key: bool = True
sample_exclusive_rows: bool = True
materialize_to_table: DbPath = None
write_limit: int = WRITE_LIMIT
materialize_all_rows: bool = False
table_write_limit: int = TABLE_WRITE_LIMIT
stats: dict = {}

def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
Expand Down Expand Up @@ -165,15 +166,20 @@ def _diff_segments(
)

db = table1.database
diff_rows, a_cols, b_cols, is_diff_cols = self._create_outer_join(table1, table2)
diff_rows, a_cols, b_cols, is_diff_cols, all_rows = self._create_outer_join(table1, table2)

with self._run_in_background(
partial(self._collect_stats, 1, table1),
partial(self._collect_stats, 2, table2),
partial(self._test_null_keys, table1, table2),
partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols),
partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols),
partial(self._materialize_diff, db, diff_rows, segment_index=segment_index)
partial(
self._materialize_diff,
db,
all_rows if self.materialize_all_rows else diff_rows,
segment_index=segment_index,
)
if self.materialize_to_table
else None,
):
Expand Down Expand Up @@ -263,10 +269,9 @@ def _create_outer_join(self, table1, table2):
a_cols = {f"table1_{c}": NormalizeAsString(a[c]) for c in cols1}
b_cols = {f"table2_{c}": NormalizeAsString(b[c]) for c in cols2}

diff_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols}).where(
or_(this[c] == 1 for c in is_diff_cols)
)
return diff_rows, a_cols, b_cols, is_diff_cols
all_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **a_cols, **b_cols})
diff_rows = all_rows.where(or_(this[c] == 1 for c in is_diff_cols))
return diff_rows, a_cols, b_cols, is_diff_cols, all_rows

def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols):
logger.info("Counting differences per column")
Expand All @@ -293,7 +298,7 @@ def exclusive_rows(expr):
c = Compiler(db)
name = c.new_unique_table_name("temp_table")
exclusive_rows = table(name, schema=expr.source_table.schema)
yield create_temp_table(c, exclusive_rows, expr.limit(self.write_limit))
yield create_temp_table(c, exclusive_rows, expr.limit(self.table_write_limit))

count = yield exclusive_rows.count()
self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0]
Expand All @@ -309,5 +314,5 @@ def exclusive_rows(expr):
def _materialize_diff(self, db, diff_rows, segment_index=None):
assert self.materialize_to_table

append_to_table(db, self.materialize_to_table, diff_rows.limit(self.write_limit))
append_to_table(db, self.materialize_to_table, diff_rows.limit(self.table_write_limit))
logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table))
11 changes: 10 additions & 1 deletion tests/test_joindiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,18 @@ def test_diff_small_tables(self):

t = TablePath(materialize_path)
rows = self.connection.query(t.select(), List[tuple])
self.connection.query(t.drop())
# is_xa, is_xb, is_diff1, is_diff2, row1, row2
assert rows == [(1, 0, 1, 1) + expected_row + (None, None)], rows
self.connection.query(t.drop())

# Test materialize all rows
mdiffer = mdiffer.replace(materialize_all_rows=True)
diff = list(mdiffer.diff_tables(self.table, self.table2))
self.assertEqual(expected, diff)
rows = self.connection.query(t.select(), List[tuple])
assert len(rows) == 2, len(rows)
self.connection.query(t.drop())


def test_diff_table_above_bisection_threshold(self):
time = "2022-01-01 00:00:00"
Expand Down

0 comments on commit 52a7092

Please sign in to comment.