Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
Merge pull request #315 from datafold/adjust_pr314
Browse files Browse the repository at this point in the history
Adjustments for PR #314
  • Loading branch information
erezsh authored Nov 25, 2022
2 parents 779892d + d304e1a commit c905430
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 7 deletions.
17 changes: 10 additions & 7 deletions data_diff/hashdiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@


def diff_sets(a: set, b: set) -> Iterator:
    """Yield the rows that differ between ``a`` and ``b``, grouped by key.

    Rows present in ``a`` but not in ``b`` are emitted as ``("-", row)``;
    rows present in ``b`` but not in ``a`` are emitted as ``("+", row)``.
    The inputs themselves are iterated (not their de-duplicated sets), so a
    differing row that occurs several times in an input is reported once
    per occurrence.

    Output is ordered by ascending key. Within one key, every "-" entry
    comes before every "+" entry, each group in input order.

    Both inputs are traversed twice (once to build the lookup set, once to
    collect differences), so they must be re-iterable collections rather
    than one-shot iterators.
    """
    # Sets are used only for fast membership tests; the rows that get
    # reported come from the original inputs so duplicates survive.
    sa = set(a)
    sb = set(b)

    # The first item is always the key (see TableDiffer.relevant_columns)
    # TODO update when we add compound keys to hashdiff
    d = defaultdict(list)
    for row in a:
        if row not in sb:
            d[row[0]].append(("-", row))
    for row in b:
        if row not in sa:
            d[row[0]].append(("+", row))

    # Keys are unique within the dict, so sorting the keys alone gives the
    # same order as sorting the (key, entries) pairs by key.
    for key in sorted(d):
        yield from d[key]
Expand Down
44 changes: 44 additions & 0 deletions tests/test_diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,3 +745,47 @@ def test_info_tree_root(self):
assert info_tree.info.is_diff
assert info_tree.info.diff_count == 1000
self.assertEqual(info_tree.info.rowcounts, {1: 1000, 2: 2000})


class TestDuplicateTables(DiffTestCase):
    """Diff two tables whose rows repeat the same key and the same data."""

    db_cls = db.MySQL

    src_schema = {"id": int, "data": str}
    dst_schema = {"id": int, "data": str}

    def setUp(self):
        """Populate both tables with duplicated rows and build the segments.

        Source table rows:
            (12, 'ABCDE') x 2
        Destination table rows:
            (4, 'ABCDEF')
            (4, 'ABCDE') x 2
            (6, 'ABCDE') x 3

        ``self.diffs`` holds the expected result: every source row as a
        "-" entry followed by every destination row as a "+" entry, with
        the id rendered as a string.
        """
        super().setUp()

        source_rows = [(12, "ABCDE"), (12, "ABCDE")]
        target_rows = [
            (4, "ABCDEF"),
            (4, "ABCDE"),
            (4, "ABCDE"),
            (6, "ABCDE"),
            (6, "ABCDE"),
            (6, "ABCDE"),
        ]

        expected_removed = [("-", (str(key), data)) for key, data in source_rows]
        expected_added = [("+", (str(key), data)) for key, data in target_rows]
        self.diffs = expected_removed + expected_added

        statements = [
            self.src_table.insert_rows(source_rows),
            self.dst_table.insert_rows(target_rows),
            commit,
        ]
        self.connection.query(statements)

        # Both segments are keyed on "id" and compare the "data" column.
        segment_kwargs = {"extra_columns": ("data",), "case_sensitive": False}
        self.a = _table_segment(self.connection, self.table_src_path, "id", **segment_kwargs)
        self.b = _table_segment(self.connection, self.table_dst_path, "id", **segment_kwargs)

    def test_duplicates(self):
        """Duplicated rows must each appear in the diff output."""
        hash_differ = HashDiffer(bisection_factor=2, bisection_threshold=4)
        actual = list(hash_differ.diff_tables(self.a, self.b))
        self.assertEqual(actual, self.diffs)

0 comments on commit c905430

Please sign in to comment.