This repository was archived by the owner on May 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 278
/
Copy pathinfo_tree.py
62 lines (47 loc) · 2.14 KB
/
info_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from typing import List, Dict, Optional, Any, Tuple, Union
import attrs
from typing_extensions import Self
from data_diff.table_segment import TableSegment
@attrs.define(frozen=False)
class SegmentInfo:
    """Mutable record of diff results for a group of table segments.

    A record is populated either directly through :meth:`set_diff` or by
    folding the records of child segments through
    :meth:`update_from_children`.
    """

    # The table segments this record describes.
    tables: List[TableSegment]

    # Differing rows; stays ``None`` until a diff has been recorded.
    diff: Optional[List[Union[Tuple[Any, ...], List[Any]]]] = None
    # Schema stored alongside ``diff`` as ``(str, type)`` pairs, when provided.
    diff_schema: Optional[Tuple[Tuple[str, type], ...]] = None
    # True when ``diff_count > 0``; ``None`` until computed.
    is_diff: Optional[bool] = None
    # Number of entries in ``diff``; ``None`` until computed.
    diff_count: Optional[int] = None

    # Row counts stored under keys 1 and 2.
    # NOTE(review): presumably one entry per table in ``tables`` -- confirm with callers.
    rowcounts: Dict[int, int] = attrs.field(factory=dict)
    # NOTE(review): not read anywhere in this class; presumably a row limit
    # for the segment -- confirm with callers.
    max_rows: Optional[int] = None

    def set_diff(
        self,
        diff: List[Union[Tuple[Any, ...], List[Any]]],
        schema: Optional[Tuple[Tuple[str, type], ...]] = None,
    ) -> None:
        """Store a diff result and derive ``diff_count`` and ``is_diff`` from it.

        Args:
            diff: The differing rows. The list is stored as-is, not copied.
            schema: Optional schema to store in ``diff_schema``. Passing
                ``None`` clears any previously stored schema.
        """
        self.diff_schema = schema
        self.diff = diff
        self.diff_count = len(diff)
        self.is_diff = self.diff_count > 0

    def update_from_children(self, child_infos) -> None:
        """Overwrite this record with the aggregate of its children's records.

        Args:
            child_infos: A non-empty iterable of ``SegmentInfo``-like objects.
                It is materialised once, so a generator is accepted.

        Raises:
            AssertionError: If ``child_infos`` is empty.
            KeyError: If a child has a non-empty ``rowcounts`` that lacks
                key 1 or key 2.
        """
        child_infos = list(child_infos)
        assert child_infos, "update_from_children() requires at least one child"

        # Children whose counters are still None are skipped.
        self.diff_count = sum(c.diff_count for c in child_infos if c.diff_count is not None)
        self.is_diff = any(c.is_diff for c in child_infos)
        # The first child that carries a schema wins.
        self.diff_schema = next((c.diff_schema for c in child_infos if c.diff_schema is not None), None)

        # Concatenate in child order with extend(): linear in the total number
        # of rows, whereas sum(lists, []) re-copies the accumulator each step.
        merged_rows: List[Union[Tuple[Any, ...], List[Any]]] = []
        for c in child_infos:
            if c.diff is not None:
                merged_rows.extend(c.diff)
        self.diff = merged_rows

        # Children with empty rowcounts contribute nothing to either total.
        self.rowcounts = {
            1: sum(c.rowcounts[1] for c in child_infos if c.rowcounts),
            2: sum(c.rowcounts[2] for c in child_infos if c.rowcounts),
        }
@attrs.define(frozen=True)
class InfoTree:
    """Tree node pairing a ``SegmentInfo`` with a list of child nodes.

    The instance is frozen, so its attributes cannot be rebound, but the
    ``children`` list itself is mutated in place by :meth:`add_node`.
    """

    # Factory used to build the ``info`` payload of nodes created by add_node().
    SEGMENT_INFO_CLASS = SegmentInfo

    # Diff information held by this node.
    info: SegmentInfo
    # Child nodes, in insertion order.
    children: List["InfoTree"] = attrs.field(factory=list)

    def add_node(self, table1: TableSegment, table2: TableSegment, max_rows: Optional[int] = None) -> Self:
        """Create a child node for the given pair of segments and return it.

        The child has the same class as ``self``, and its ``info`` is built
        with that class's ``SEGMENT_INFO_CLASS``.

        Args:
            table1: First table segment of the pair.
            table2: Second table segment of the pair.
            max_rows: Forwarded unchanged to the info object's ``max_rows``.

        Returns:
            The newly created child, already appended to ``self.children``.
        """
        node_cls = self.__class__
        segment_info = node_cls.SEGMENT_INFO_CLASS([table1, table2], max_rows=max_rows)
        child = node_cls(segment_info)
        self.children.append(child)
        return child

    def aggregate_info(self) -> None:
        """Recursively fold the children's info into this node's info.

        A node without children is left untouched.
        """
        if not self.children:
            return

        # Each subtree is aggregated first so that its info is final before
        # it is folded into this node.
        for subtree in self.children:
            subtree.aggregate_info()
        self.info.update_from_children(subtree.info for subtree in self.children)