Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit c1b24ef

Browse files
author
Sergey Vasilyev
committed
Convert the runtype classes to attrs
`attrs` is much more beneficial: * `attrs` is supported by type checkers, such as MyPy & IDEs * `attrs` is widely used and industry-proven * `attrs` is explicit in its declarations, there is no magic * `attrs` has slots But mainly for the first item — type checking by type checkers. Those that were runtype classes, are frozen. Those that were not, are unfrozen for now, but we can freeze them later if and where it works (the stricter, the better).
1 parent 1e0f310 commit c1b24ef

19 files changed

+191
-182
lines changed

Diff for: data_diff/abcs/compiler.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from abc import ABC
22

3+
import attrs
34

5+
6+
@attrs.define(frozen=False)
47
class AbstractCompiler(ABC):
58
pass
69

710

11+
@attrs.define(frozen=False, eq=False)
812
class Compilable(ABC):
913
pass

Diff for: data_diff/abcs/database_types.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Tuple, Union
44
from datetime import datetime
55

6-
from runtype import dataclass
6+
import attrs
77

88
from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown
99

@@ -13,14 +13,14 @@
1313
DbTime = datetime
1414

1515

16-
@dataclass
16+
@attrs.define(frozen=True)
1717
class ColType:
1818
@property
1919
def supported(self) -> bool:
2020
return True
2121

2222

23-
@dataclass
23+
@attrs.define(frozen=True)
2424
class PrecisionType(ColType):
2525
precision: int
2626
rounds: Union[bool, Unknown] = Unknown
@@ -50,7 +50,7 @@ class Date(TemporalType):
5050
pass
5151

5252

53-
@dataclass
53+
@attrs.define(frozen=True)
5454
class NumericType(ColType):
5555
# 'precision' signifies how many fractional digits (after the dot) we want to compare
5656
precision: int
@@ -84,7 +84,7 @@ def python_type(self) -> type:
8484
return decimal.Decimal
8585

8686

87-
@dataclass
87+
@attrs.define(frozen=True)
8888
class StringType(ColType):
8989
python_type = str
9090

@@ -122,7 +122,7 @@ class String_VaryingAlphanum(String_Alphanum):
122122
pass
123123

124124

125-
@dataclass
125+
@attrs.define(frozen=True)
126126
class String_FixedAlphanum(String_Alphanum):
127127
length: int
128128

@@ -132,20 +132,20 @@ def make_value(self, value):
132132
return self.python_type(value, max_len=self.length)
133133

134134

135-
@dataclass
135+
@attrs.define(frozen=True)
136136
class Text(StringType):
137137
@property
138138
def supported(self) -> bool:
139139
return False
140140

141141

142142
# In majority of DBMSes, it is called JSON/JSONB. Only in Snowflake, it is OBJECT.
143-
@dataclass
143+
@attrs.define(frozen=True)
144144
class JSON(ColType):
145145
pass
146146

147147

148-
@dataclass
148+
@attrs.define(frozen=True)
149149
class Array(ColType):
150150
item_type: ColType
151151

@@ -155,21 +155,21 @@ class Array(ColType):
155155
# For example, in BigQuery:
156156
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#struct_type
157157
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#struct_literals
158-
@dataclass
158+
@attrs.define(frozen=True)
159159
class Struct(ColType):
160160
pass
161161

162162

163-
@dataclass
163+
@attrs.define(frozen=True)
164164
class Integer(NumericType, IKey):
165165
precision: int = 0
166166
python_type: type = int
167167

168-
def __post_init__(self):
168+
def __attrs_post_init__(self):
169169
assert self.precision == 0
170170

171171

172-
@dataclass
172+
@attrs.define(frozen=True)
173173
class UnknownColType(ColType):
174174
text: str
175175

Diff for: data_diff/cloud/datafold_api.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import base64
2-
import dataclasses
32
import enum
43
import time
54
from typing import Any, Dict, List, Optional, Type, Tuple
65

6+
import attrs
77
import pydantic
88
import requests
99
from typing_extensions import Self
@@ -178,13 +178,13 @@ class TCloudApiDataSourceTestResult(pydantic.BaseModel):
178178
result: Optional[TCloudDataSourceTestResult]
179179

180180

181-
@dataclasses.dataclass
181+
@attrs.define(frozen=True)
182182
class DatafoldAPI:
183183
api_key: str
184184
host: str = "https://app.datafold.com"
185185
timeout: int = 30
186186

187-
def __post_init__(self):
187+
def __attrs_post_init__(self):
188188
self.host = self.host.rstrip("/")
189189
self.headers = {
190190
"Authorization": f"Key {self.api_key}",

Diff for: data_diff/databases/_connect.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from itertools import zip_longest
44
from contextlib import suppress
55
import weakref
6+
7+
import attrs
68
import dsnparse
79
import toml
810

9-
from runtype import dataclass
1011
from typing_extensions import Self
1112

1213
from data_diff.databases.base import Database, ThreadedDatabase
@@ -25,7 +26,7 @@
2526
from data_diff.databases.mssql import MsSQL
2627

2728

28-
@dataclass
29+
@attrs.define(frozen=True)
2930
class MatchUriPath:
3031
database_cls: Type[Database]
3132

Diff for: data_diff/databases/base.py

+19-22
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import abc
22
import functools
3-
from dataclasses import field
3+
import random
44
from datetime import datetime
55
import math
66
import sys
@@ -14,7 +14,7 @@
1414
import decimal
1515
import contextvars
1616

17-
from runtype import dataclass
17+
import attrs
1818
from typing_extensions import Self
1919

2020
from data_diff.abcs.compiler import AbstractCompiler
@@ -90,12 +90,7 @@ class CompileError(Exception):
9090
pass
9191

9292

93-
# TODO: remove once switched to attrs, where ForwardRef[]/strings are resolved.
94-
class _RuntypeHackToFixCicularRefrencedDatabase:
95-
dialect: "BaseDialect"
96-
97-
98-
@dataclass
93+
@attrs.define(frozen=True)
9994
class Compiler(AbstractCompiler):
10095
"""
10196
Compiler bears the context for a single compilation.
@@ -107,16 +102,16 @@ class Compiler(AbstractCompiler):
107102
# Database is needed to normalize tables. Dialect is needed for recursive compilations.
108103
# In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects.
109104
# In practice, we currently bind the dialects to the specific database classes.
110-
database: _RuntypeHackToFixCicularRefrencedDatabase
105+
database: "Database"
111106

112107
in_select: bool = False # Compilation runtime flag
113108
in_join: bool = False # Compilation runtime flag
114109

115-
_table_context: List = field(default_factory=list) # List[ITable]
116-
_subqueries: Dict[str, Any] = field(default_factory=dict) # XXX not thread-safe
110+
_table_context: List = attrs.field(factory=list) # List[ITable]
111+
_subqueries: Dict[str, Any] = attrs.field(factory=dict) # XXX not thread-safe
117112
root: bool = True
118113

119-
_counter: List = field(default_factory=lambda: [0])
114+
_counter: List = attrs.field(factory=lambda: [0])
120115

121116
@property
122117
def dialect(self) -> "BaseDialect":
@@ -136,7 +131,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath:
136131
return self.database.dialect.parse_table_name(table_name)
137132

138133
def add_table_context(self, *tables: Sequence, **kw) -> Self:
139-
return self.replace(_table_context=self._table_context + list(tables), **kw)
134+
return attrs.evolve(self, table_context=self._table_context + list(tables), **kw)
140135

141136

142137
def parse_table_name(t):
@@ -271,7 +266,7 @@ def _compile(self, compiler: Compiler, elem) -> str:
271266
if elem is None:
272267
return "NULL"
273268
elif isinstance(elem, Compilable):
274-
return self.render_compilable(compiler.replace(root=False), elem)
269+
return self.render_compilable(attrs.evolve(compiler, root=False), elem)
275270
elif isinstance(elem, str):
276271
return f"'{elem}'"
277272
elif isinstance(elem, (int, float)):
@@ -381,7 +376,7 @@ def render_column(self, c: Compiler, elem: Column) -> str:
381376
return self.quote(elem.name)
382377

383378
def render_cte(self, parent_c: Compiler, elem: Cte) -> str:
384-
c: Compiler = parent_c.replace(_table_context=[], in_select=False)
379+
c: Compiler = attrs.evolve(parent_c, table_context=[], in_select=False)
385380
compiled = self.compile(c, elem.source_table)
386381

387382
name = elem.name or parent_c.new_unique_name()
@@ -494,7 +489,7 @@ def render_tablealias(self, c: Compiler, elem: TableAlias) -> str:
494489
return f"{self.compile(c, elem.source_table)} {self.quote(elem.name)}"
495490

496491
def render_tableop(self, parent_c: Compiler, elem: TableOp) -> str:
497-
c: Compiler = parent_c.replace(in_select=False)
492+
c: Compiler = attrs.evolve(parent_c, in_select=False)
498493
table_expr = f"{self.compile(c, elem.table1)} {elem.op} {self.compile(c, elem.table2)}"
499494
if parent_c.in_select:
500495
table_expr = f"({table_expr}) {c.new_unique_name()}"
@@ -506,7 +501,7 @@ def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str:
506501
return self.compile(c, elem._get_resolved())
507502

508503
def render_select(self, parent_c: Compiler, elem: Select) -> str:
509-
c: Compiler = parent_c.replace(in_select=True) # .add_table_context(self.table)
504+
c: Compiler = attrs.evolve(parent_c, in_select=True) # .add_table_context(self.table)
510505
compile_fn = functools.partial(self.compile, c)
511506

512507
columns = ", ".join(map(compile_fn, elem.columns)) if elem.columns else "*"
@@ -544,7 +539,8 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str:
544539

545540
def render_join(self, parent_c: Compiler, elem: Join) -> str:
546541
tables = [
547-
t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in elem.source_tables
542+
t if isinstance(t, TableAlias) else TableAlias(source_table=t, name=parent_c.new_unique_name())
543+
for t in elem.source_tables
548544
]
549545
c = parent_c.add_table_context(*tables, in_join=True, in_select=False)
550546
op = " JOIN " if elem.op is None else f" {elem.op} JOIN "
@@ -577,7 +573,8 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
577573
if isinstance(elem.table, Select) and elem.table.columns is None and elem.table.group_by_exprs is None:
578574
return self.compile(
579575
c,
580-
elem.table.replace(
576+
attrs.evolve(
577+
elem.table,
581578
columns=columns,
582579
group_by_exprs=[Code(k) for k in keys],
583580
having_exprs=elem.having_exprs,
@@ -589,7 +586,7 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
589586
having_str = (
590587
" HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) if elem.having_exprs is not None else ""
591588
)
592-
select = f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}"
589+
select = f"SELECT {columns_str} FROM {self.compile(attrs.evolve(c, in_select=True), elem.table)} GROUP BY {keys_str}{having_str}"
593590

594591
if c.in_select:
595592
select = f"({select}) {c.new_unique_name()}"
@@ -815,7 +812,7 @@ def set_timezone_to_utc(self) -> str:
815812
T = TypeVar("T", bound=BaseDialect)
816813

817814

818-
@dataclass
815+
@attrs.define(frozen=True)
819816
class QueryResult:
820817
rows: list
821818
columns: Optional[list] = None
@@ -830,7 +827,7 @@ def __getitem__(self, i):
830827
return self.rows[i]
831828

832829

833-
class Database(abc.ABC, _RuntypeHackToFixCicularRefrencedDatabase):
830+
class Database(abc.ABC):
834831
"""Base abstract class for databases.
835832
836833
Used for providing connection code and implementation specific SQL utilities.

Diff for: data_diff/diff_tables.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33

44
import time
55
from abc import ABC, abstractmethod
6-
from dataclasses import field
76
from enum import Enum
87
from contextlib import contextmanager
98
from operator import methodcaller
109
from typing import Dict, Tuple, Iterator, Optional
1110
from concurrent.futures import ThreadPoolExecutor, as_completed
1211

13-
from runtype import dataclass
12+
import attrs
1413

1514
from data_diff.info_tree import InfoTree, SegmentInfo
1615
from data_diff.utils import dbt_diff_string_template, run_as_daemon, safezip, getLogger, truncate_error, Vector
@@ -31,7 +30,7 @@ class Algorithm(Enum):
3130
DiffResult = Iterator[Tuple[str, tuple]] # Iterator[Tuple[Literal["+", "-"], tuple]]
3231

3332

34-
@dataclass
33+
@attrs.define(frozen=True)
3534
class ThreadBase:
3635
"Provides utility methods for optional threading"
3736

@@ -72,7 +71,7 @@ def _run_in_background(self, *funcs):
7271
f.result()
7372

7473

75-
@dataclass
74+
@attrs.define(frozen=True)
7675
class DiffStats:
7776
diff_by_sign: Dict[str, int]
7877
table1_count: int
@@ -82,12 +81,12 @@ class DiffStats:
8281
extra_column_diffs: Optional[Dict[str, int]]
8382

8483

85-
@dataclass
84+
@attrs.define(frozen=True)
8685
class DiffResultWrapper:
8786
diff: iter # DiffResult
8887
info_tree: InfoTree
8988
stats: dict
90-
result_list: list = field(default_factory=list)
89+
result_list: list = attrs.field(factory=list)
9190

9291
def __iter__(self):
9392
yield from self.result_list
@@ -203,7 +202,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment, info_tree: Inf
203202

204203
def _diff_tables_wrapper(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult:
205204
if is_tracking_enabled():
206-
options = dict(self)
205+
options = attrs.asdict(self, recurse=False)
207206
options["differ_name"] = type(self).__name__
208207
event_json = create_start_event_json(options)
209208
run_as_daemon(send_event_json, event_json)

0 commit comments

Comments
 (0)