 import dataclasses
 import numbers
 from collections import defaultdict, namedtuple, OrderedDict
-from typing import List
+from dataclasses import InitVar
+from typing import Any, ClassVar, List, Optional

 import numpy as np
 import pytest
@@ -31,6 +32,12 @@ class Feature:
         input_ids: torch.Tensor
         segment_ids: np.ndarray

+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, Feature):
+                return NotImplemented
+            else:
+                return torch.equal(self.input_ids, o.input_ids) and np.equal(self.segment_ids, o.segment_ids).all()
+
     @dataclasses.dataclass
     class ModelExample:
         example_ids: List[str]
@@ -41,6 +48,71 @@ class ModelExample:
         def __post_init__(self):
             self.some_constant = 7

+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, ModelExample):
+                return NotImplemented
+            else:
+                return (
+                    self.example_ids == o.example_ids
+                    and self.feature == o.feature
+                    and torch.equal(self.label, o.label)
+                    and self.some_constant == o.some_constant
+                )
+
+    @dataclasses.dataclass
+    class WithClassVar:
+        class_var: ClassVar[int] = 0
+        dummy: Any
+
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, WithClassVar):
+                return NotImplemented
+            elif isinstance(self.dummy, torch.Tensor):
+                return torch.equal(self.dummy, o.dummy)
+            else:
+                return self.dummy == o.dummy
+
+    @dataclasses.dataclass
+    class WithInitVar:
+        dummy: Any
+        override: InitVar[Optional[Any]] = None
+
+        def __post_init__(self, override: Optional[Any]):
+            if override is not None:
+                self.dummy = override
+
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, WithInitVar):
+                return NotImplemented
+            elif isinstance(self.dummy, torch.Tensor):
+                return torch.equal(self.dummy, o.dummy)
+            else:
+                return self.dummy == o.dummy
+
+    @dataclasses.dataclass
+    class WithClassAndInitVar:
+        class_var: ClassVar[torch.Tensor] = torch.tensor(0)
+        dummy: Any
+        override: InitVar[Optional[Any]] = torch.tensor(1)
+
+        def __post_init__(self, override: Optional[Any]):
+            if override is not None:
+                self.dummy = override
+
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, WithClassAndInitVar):
+                return NotImplemented
+            elif isinstance(self.dummy, torch.Tensor):
+                return torch.equal(self.dummy, o.dummy)
+            else:
+                return self.dummy == o.dummy
+
+    model_example = ModelExample(
+        example_ids=["i-1", "i-2", "i-3"],
+        feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
+        label=torch.tensor([7.0, 8.0, 9.0]),
+    )
+
     to_reduce = {
         "a": torch.tensor([1.0]),  # Tensor
         "b": [torch.tensor([2.0])],  # list
@@ -50,13 +122,18 @@ def __post_init__(self):
         "f": "this_is_a_dummy_str",  # string
         "g": 12.0,  # number
         "h": Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),  # dataclass
-        "i": ModelExample(
-            example_ids=["i-1", "i-2", "i-3"],
-            feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
-            label=torch.tensor([7.0, 8.0, 9.0]),
-        ),  # nested dataclass
+        "i": model_example,  # nested dataclass
+        "j": WithClassVar(torch.arange(3)),  # dataclass with class variable
+        "k": WithInitVar("this_gets_overridden", torch.tensor([2.0])),  # dataclass with init-only variable
+        "l": WithClassAndInitVar(model_example, None),  # nested dataclass with class and init-only variables
     }

+    model_example_result = ModelExample(
+        example_ids=["i-1", "i-2", "i-3"],
+        feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
+        label=torch.tensor([14.0, 16.0, 18.0]),
+    )
+
     expected_result = {
         "a": torch.tensor([2.0]),
         "b": [torch.tensor([4.0])],
@@ -66,32 +143,31 @@ def __post_init__(self):
         "f": "this_is_a_dummy_str",
         "g": 24.0,
         "h": Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
-        "i": ModelExample(
-            example_ids=["i-1", "i-2", "i-3"],
-            feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
-            label=torch.tensor([14.0, 16.0, 18.0]),
-        ),
+        "i": model_example_result,
+        "j": WithClassVar(torch.arange(0, 6, 2)),
+        "k": WithInitVar(torch.tensor([4.0])),
+        "l": WithClassAndInitVar(model_example_result, None),
     }

     reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)

-    assert isinstance(reduced, dict), " Type Consistency of dict not preserved"
+    assert isinstance(reduced, dict), "Type Consistency of dict not preserved"
     assert all(x in reduced for x in to_reduce), "Not all entries of the dict were preserved"
     assert all(
         isinstance(reduced[k], type(expected_result[k])) for k in to_reduce
     ), "At least one type was not correctly preserved"

     assert isinstance(reduced["a"], torch.Tensor), "Reduction Result of a Tensor should be a Tensor"
-    assert torch.allclose(expected_result["a"], reduced["a"]), "Reduction of a tensor does not yield the expected value"
+    assert torch.equal(expected_result["a"], reduced["a"]), "Reduction of a tensor does not yield the expected value"

     assert isinstance(reduced["b"], list), "Reduction Result of a list should be a list"
     assert all(
-        torch.allclose(x, y) for x, y in zip(reduced["b"], expected_result["b"])
+        torch.equal(x, y) for x, y in zip(reduced["b"], expected_result["b"])
     ), "At least one value of list reduction did not come out as expected"

     assert isinstance(reduced["c"], tuple), "Reduction Result of a tuple should be a tuple"
     assert all(
-        torch.allclose(x, y) for x, y in zip(reduced["c"], expected_result["c"])
+        torch.equal(x, y) for x, y in zip(reduced["c"], expected_result["c"])
     ), "At least one value of tuple reduction did not come out as expected"

     assert isinstance(reduced["d"], ntc), "Type Consistency for named tuple not given"
@@ -109,34 +185,30 @@ def __post_init__(self):
     assert isinstance(reduced["g"], numbers.Number), "Reduction of a number should result in a number"
     assert reduced["g"] == expected_result["g"], "Reduction of a number did not yield the desired result"

-    assert dataclasses.is_dataclass(reduced["h"]) and not isinstance(
-        reduced["h"], type
-    ), "Reduction of a dataclass should result in a dataclass"
-    assert torch.allclose(
-        reduced["h"].input_ids, expected_result["h"].input_ids
-    ), "Reduction of a dataclass did not yield the desired result"
-    assert np.allclose(
-        reduced["h"].segment_ids, expected_result["h"].segment_ids
-    ), "Reduction of a dataclass did not yield the desired result"
-
-    assert dataclasses.is_dataclass(reduced["i"]) and not isinstance(
-        reduced["i"], type
-    ), "Reduction of a dataclass should result in a dataclass"
-    assert dataclasses.is_dataclass(reduced["i"].feature) and not isinstance(
-        reduced["i"].feature, type
-    ), "Reduction of a nested dataclass should result in a nested dataclass"
-    assert (
-        reduced["i"].example_ids == expected_result["i"].example_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
-    assert torch.allclose(
-        reduced["i"].label, expected_result["i"].label
-    ), "Reduction of a nested dataclass did not yield the desired result"
-    assert torch.allclose(
-        reduced["i"].feature.input_ids, expected_result["i"].feature.input_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
-    assert np.allclose(
-        reduced["i"].feature.segment_ids, expected_result["i"].feature.segment_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
+    def _assert_dataclass_reduction(actual, expected, dataclass_type: str = ""):
+        assert dataclasses.is_dataclass(actual) and not isinstance(
+            actual, type
+        ), f"Reduction of a {dataclass_type} dataclass should result in a dataclass"
+        for field in dataclasses.fields(actual):
+            if dataclasses.is_dataclass(field.type):
+                _assert_dataclass_reduction(getattr(actual, field.name), getattr(expected, field.name), "nested")
+        assert actual == expected, f"Reduction of a {dataclass_type} dataclass did not yield the desired result"
+
+    _assert_dataclass_reduction(reduced["h"], expected_result["h"])
+
+    _assert_dataclass_reduction(reduced["i"], expected_result["i"])
+
+    dataclass_type = "ClassVar-containing"
+    _assert_dataclass_reduction(reduced["j"], expected_result["j"], dataclass_type)
+    assert WithClassVar.class_var == 0, f"Reduction of a {dataclass_type} dataclass should not change the class var"
+
+    _assert_dataclass_reduction(reduced["k"], expected_result["k"], "InitVar-containing")
+
+    dataclass_type = "Class-and-InitVar-containing"
+    _assert_dataclass_reduction(reduced["l"], expected_result["l"], dataclass_type)
+    assert torch.equal(
+        WithClassAndInitVar.class_var, torch.tensor(0)
+    ), f"Reduction of a {dataclass_type} dataclass should not change the class var"

     # mapping support
     reduced = apply_to_collection({"a": 1, "b": 2}, int, lambda x: str(x))
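For context on what the new `WithClassVar`, `WithInitVar`, and `WithClassAndInitVar` cases pin down: `apply_to_collection` should map the given function over a dataclass's regular fields while leaving `ClassVar` attributes untouched and not re-running init-only (`InitVar`) logic. Below is a minimal sketch of that behavior, assuming the `pytorch_lightning.utilities.apply_func` import path this test file exercises; the `Example` dataclass is purely illustrative and not part of the change.

```python
import dataclasses
from dataclasses import InitVar
from typing import Any, ClassVar, Optional

import torch

# Assumed import path, matching the module the test targets.
from pytorch_lightning.utilities.apply_func import apply_to_collection


@dataclasses.dataclass
class Example:
    class_var: ClassVar[int] = 0               # class-level attribute: not a dataclass field, left untouched
    value: torch.Tensor = torch.tensor([1.0])  # regular field: the mapped function is applied here
    override: InitVar[Optional[Any]] = None    # init-only: consumed by __post_init__, never stored

    def __post_init__(self, override: Optional[Any]) -> None:
        if override is not None:
            self.value = override


doubled = apply_to_collection(Example(), torch.Tensor, lambda t: t * 2)
assert torch.equal(doubled.value, torch.tensor([2.0]))  # tensor field was doubled
assert Example.class_var == 0                           # class variable is unchanged
```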