diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index 0984d9f1..dd0c3664 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -39,22 +39,34 @@ def diff_sets( ignored_columns1: Collection[str], ignored_columns2: Collection[str], ) -> Iterator: - # Differ only by columns of interest (PKs+relevant-ignored). But yield with ignored ones! - sa: Set[_Row] = {tuple(val for col, val in safezip(columns1, row) if col not in ignored_columns1) for row in a} - sb: Set[_Row] = {tuple(val for col, val in safezip(columns2, row) if col not in ignored_columns2) for row in b} - - # The first items are always the PK (see TableSegment.relevant_columns) - diffs_by_pks: Dict[_PK, List[Tuple[_Op, _Row]]] = defaultdict(list) + # Group full rows by PKs on each side. The first items are the PK: TableSegment.relevant_columns + rows_by_pks1: Dict[_PK, List[_Row]] = defaultdict(list) + rows_by_pks2: Dict[_PK, List[_Row]] = defaultdict(list) for row in a: pk: _PK = tuple(val for col, val in zip(key_columns1, row)) - cutrow: _Row = tuple(val for col, val in zip(columns1, row) if col not in ignored_columns1) - if cutrow not in sb: - diffs_by_pks[pk].append(("-", row)) + rows_by_pks1[pk].append(row) for row in b: pk: _PK = tuple(val for col, val in zip(key_columns2, row)) - cutrow: _Row = tuple(val for col, val in zip(columns2, row) if col not in ignored_columns2) - if cutrow not in sa: - diffs_by_pks[pk].append(("+", row)) + rows_by_pks2[pk].append(row) + + # Mind that the same pk MUST go in full with all the -/+ rows all at once, for grouping. + diffs_by_pks: Dict[_PK, List[Tuple[_Op, _Row]]] = defaultdict(list) + for pk in sorted(set(rows_by_pks1) | set(rows_by_pks2)): + cutrows1: List[_Row] = [ + tuple(val for col, val in zip(columns1, row1) if col not in ignored_columns1) for row1 in rows_by_pks1[pk] + ] + cutrows2: List[_Row] = [ + tuple(val for col, val in zip(columns2, row2) if col not in ignored_columns2) for row2 in rows_by_pks2[pk] + ] + + # Either side has 0 rows: a clearly exclusive row. + # Either side has 2+ rows: duplicates on either side, yield it all regardless of values. + # Both sides == 1: non-duplicate, non-exclusive, so check for values of interest. + if len(cutrows1) != 1 or len(cutrows2) != 1 or cutrows1 != cutrows2: + for row1 in rows_by_pks1[pk]: + diffs_by_pks[pk].append(("-", row1)) + for row2 in rows_by_pks2[pk]: + diffs_by_pks[pk].append(("+", row2)) warned_diff_cols = set() for diffs in (diffs_by_pks[pk] for pk in sorted(diffs_by_pks)): diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 9a975089..705bf55e 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -818,11 +818,11 @@ def test_simple2(self): V1 = N + 1 V2 = N * 1000 + 2 - diffs = [(i, i + N) for i in range(N)] + diffs = [(i + 1, i + N) for i in range(N)] # pk=[1..1000], no dupes self.connection.query( [ - self.src_table.insert_rows(diffs + [(K, V1)]), - self.dst_table.insert_rows(diffs + [(0, V2)]), + self.src_table.insert_rows(diffs + [(K, V1)]), # exclusive pk=1001 + self.dst_table.insert_rows(diffs + [(0, V2)]), # exclusive pk=0 commit, ] )