From 8944e5fe81a0443e9fd6303308773698e07bac8a Mon Sep 17 00:00:00 2001
From: Sergey Vasilyev <sv@datafold.com>
Date: Thu, 11 Jan 2024 12:20:19 +0100
Subject: [PATCH] Detect duplicate rows on each side

---
 data_diff/hashdiff_tables.py | 36 ++++++++++++++++++++++++------------
 tests/test_diff_tables.py    |  6 +++---
 2 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py
index 0984d9f1..dd0c3664 100644
--- a/data_diff/hashdiff_tables.py
+++ b/data_diff/hashdiff_tables.py
@@ -39,22 +39,34 @@ def diff_sets(
     ignored_columns1: Collection[str],
     ignored_columns2: Collection[str],
 ) -> Iterator:
-    # Differ only by columns of interest (PKs+relevant-ignored). But yield with ignored ones!
-    sa: Set[_Row] = {tuple(val for col, val in safezip(columns1, row) if col not in ignored_columns1) for row in a}
-    sb: Set[_Row] = {tuple(val for col, val in safezip(columns2, row) if col not in ignored_columns2) for row in b}
-
-    # The first items are always the PK (see TableSegment.relevant_columns)
-    diffs_by_pks: Dict[_PK, List[Tuple[_Op, _Row]]] = defaultdict(list)
+    # Group full rows by PKs on each side. The first items are the PK: TableSegment.relevant_columns
+    rows_by_pks1: Dict[_PK, List[_Row]] = defaultdict(list)
+    rows_by_pks2: Dict[_PK, List[_Row]] = defaultdict(list)
     for row in a:
         pk: _PK = tuple(val for col, val in zip(key_columns1, row))
-        cutrow: _Row = tuple(val for col, val in zip(columns1, row) if col not in ignored_columns1)
-        if cutrow not in sb:
-            diffs_by_pks[pk].append(("-", row))
+        rows_by_pks1[pk].append(row)
     for row in b:
         pk: _PK = tuple(val for col, val in zip(key_columns2, row))
-        cutrow: _Row = tuple(val for col, val in zip(columns2, row) if col not in ignored_columns2)
-        if cutrow not in sa:
-            diffs_by_pks[pk].append(("+", row))
+        rows_by_pks2[pk].append(row)
+
+    # Mind that the same pk MUST go in full with all the -/+ rows all at once, for grouping.
+    diffs_by_pks: Dict[_PK, List[Tuple[_Op, _Row]]] = defaultdict(list)
+    for pk in sorted(set(rows_by_pks1) | set(rows_by_pks2)):
+        cutrows1: List[_Row] = [
+            tuple(val for col, val in zip(columns1, row1) if col not in ignored_columns1) for row1 in rows_by_pks1[pk]
+        ]
+        cutrows2: List[_Row] = [
+            tuple(val for col, val in zip(columns2, row2) if col not in ignored_columns2) for row2 in rows_by_pks2[pk]
+        ]
+
+        # Either side has 0 rows: a clearly exclusive row.
+        # Either side has 2+ rows: duplicates on either side, yield it all regardless of values.
+        # Both sides == 1: non-duplicate, non-exclusive, so check for values of interest.
+        if len(cutrows1) != 1 or len(cutrows2) != 1 or cutrows1 != cutrows2:
+            for row1 in rows_by_pks1[pk]:
+                diffs_by_pks[pk].append(("-", row1))
+            for row2 in rows_by_pks2[pk]:
+                diffs_by_pks[pk].append(("+", row2))
 
     warned_diff_cols = set()
     for diffs in (diffs_by_pks[pk] for pk in sorted(diffs_by_pks)):
diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py
index 9a975089..705bf55e 100644
--- a/tests/test_diff_tables.py
+++ b/tests/test_diff_tables.py
@@ -818,11 +818,11 @@ def test_simple2(self):
         V1 = N + 1
         V2 = N * 1000 + 2
 
-        diffs = [(i, i + N) for i in range(N)]
+        diffs = [(i + 1, i + N) for i in range(N)]  # pk=[1..1000], no dupes
         self.connection.query(
             [
-                self.src_table.insert_rows(diffs + [(K, V1)]),
-                self.dst_table.insert_rows(diffs + [(0, V2)]),
+                self.src_table.insert_rows(diffs + [(K, V1)]),  # exclusive pk=1001
+                self.dst_table.insert_rows(diffs + [(0, V2)]),  # exclusive pk=0
                 commit,
             ]
         )