from typing import List, Dict, Optional, Any, Tuple, Union

import attrs
from typing_extensions import Self

from data_diff.table_segment import TableSegment


@attrs.define(frozen=False)
class SegmentInfo:
    tables: List[TableSegment]

    diff: Optional[List[Union[Tuple[Any, ...], List[Any]]]] = None
    diff_schema: Optional[Tuple[Tuple[str, type], ...]] = None
    is_diff: Optional[bool] = None
    diff_count: Optional[int] = None

    rowcounts: Dict[int, int] = attrs.field(factory=dict)
    max_rows: Optional[int] = None

    def set_diff(
        self, diff: List[Union[Tuple[Any, ...], List[Any]]], schema: Optional[Tuple[Tuple[str, type]]] = None
    ) -> None:
        self.diff_schema = schema
        self.diff = diff
        self.diff_count = len(diff)
        self.is_diff = self.diff_count > 0

    def update_from_children(self, child_infos) -> None:
        child_infos = list(child_infos)
        assert child_infos

        # self.diff = list(chain(*[c.diff for c in child_infos]))
        self.diff_count = sum(c.diff_count for c in child_infos if c.diff_count is not None)
        self.is_diff = any(c.is_diff for c in child_infos)
        self.diff_schema = next((child.diff_schema for child in child_infos if child.diff_schema is not None), None)
        self.diff = sum((c.diff for c in child_infos if c.diff is not None), [])

        self.rowcounts = {
            1: sum(c.rowcounts[1] for c in child_infos if c.rowcounts),
            2: sum(c.rowcounts[2] for c in child_infos if c.rowcounts),
        }


@attrs.define(frozen=True)
class InfoTree:
    SEGMENT_INFO_CLASS = SegmentInfo

    info: SegmentInfo
    children: List["InfoTree"] = attrs.field(factory=list)

    def add_node(self, table1: TableSegment, table2: TableSegment, max_rows: Optional[int] = None) -> Self:
        cls = self.__class__
        node = cls(cls.SEGMENT_INFO_CLASS([table1, table2], max_rows=max_rows))
        self.children.append(node)
        return node

    def aggregate_info(self) -> None:
        if self.children:
            for c in self.children:
                c.aggregate_info()
            self.info.update_from_children(c.info for c in self.children)