
Commit f5e3907

awaelchli and carmocca committed
Fix support for dataclasses with ClassVar/InitVar in apply_to_collection (#9702)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent e42072c commit f5e3907

6 files changed (+159 additions, −56 deletions)


CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

 - Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))
-
+- Fixed an issue where class or init-only variables of dataclasses were passed to the dataclass constructor in `utilities.apply_to_collection` ([#9702](https://github.com/PyTorchLightning/pytorch-lightning/issues/9702))


 ## [1.5.1] - 2021-11-09
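Note: the entry above refers to the old behaviour of rebuilding a dataclass by calling its constructor with the mapped field values. A minimal sketch of why that cannot work for init-only variables, modelled on the `WithInitVar` test case added in this commit (the class below is illustrative, not library code):

from dataclasses import InitVar, dataclass, fields
from typing import Any, Optional


@dataclass
class WithInitVar:
    dummy: Any
    override: InitVar[Optional[Any]] = None

    def __post_init__(self, override: Optional[Any]) -> None:
        if override is not None:
            self.dummy = override


obj = WithInitVar("initial", "overridden")
print(obj.dummy)                      # 'overridden' -- consumed by __post_init__, not stored as a field
print([f.name for f in fields(obj)])  # ['dummy'] -- 'override' is an init-only pseudo-field
# Rebuilding via WithInitVar(**{f.name: getattr(obj, f.name) for f in fields(obj)})
# cannot reproduce the original constructor call and re-runs __post_init__,
# which is why apply_to_collection now copies the instance instead (see apply_func.py below).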

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 23 additions & 3 deletions
@@ -51,8 +51,8 @@ class _Sync:
     fn: Optional[Callable] = None
     _should: bool = False
     rank_zero_only: bool = False
-    op: Optional[str] = None
-    group: Optional[Any] = None
+    _op: Optional[str] = None
+    _group: Optional[Any] = None

     def __post_init__(self) -> None:
         self._generate_sync_fn()
@@ -67,6 +67,26 @@ def should(self, should: bool) -> None:
         # `self._fn` needs to be re-generated.
         self._generate_sync_fn()

+    @property
+    def op(self) -> Optional[str]:
+        return self._op
+
+    @op.setter
+    def op(self, op: Optional[str]) -> None:
+        self._op = op
+        # `self._fn` needs to be re-generated.
+        self._generate_sync_fn()
+
+    @property
+    def group(self) -> Optional[Any]:
+        return self._group
+
+    @group.setter
+    def group(self, group: Optional[Any]) -> None:
+        self._group = group
+        # `self._fn` needs to be re-generated.
+        self._generate_sync_fn()
+
     def _generate_sync_fn(self) -> None:
         """Used to compute the syncing function and cache it."""
         fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
@@ -426,7 +446,7 @@ def log(
             dataloader_idx=dataloader_idx,
             metric_attribute=metric_attribute,
         )
-        meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only)
+        meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, _group=sync_dist_group, rank_zero_only=rank_zero_only)

         # register logged value if it doesn't exist
         if key not in self:
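Note: `op` and `group` are now properties backed by the private `_op` and `_group` dataclass fields, so assigning to them re-generates the cached sync function, and the generated `__init__` accepts the underscored names — which is why the `_Sync(...)` constructor calls in the tests further down change. A minimal sketch of this pattern (the `CachedSync` class is hypothetical, not the Lightning implementation):

from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class CachedSync:
    # hypothetical stand-in for the `_Sync` pattern: the stored field is underscored,
    # and the public name is a property whose setter rebuilds the cached callable
    fn: Optional[Callable] = None
    _op: Optional[str] = None

    def __post_init__(self) -> None:
        self._generate_sync_fn()

    @property
    def op(self) -> Optional[str]:
        return self._op

    @op.setter
    def op(self, op: Optional[str]) -> None:
        self._op = op
        # the cached callable depends on `op`, so it must be re-generated
        self._generate_sync_fn()

    def _generate_sync_fn(self) -> None:
        self._fn = (lambda x: x) if self.fn is None or self._op is None else self.fn

    def __call__(self, value: Any) -> Any:
        return self._fn(value)


sync = CachedSync(fn=abs, _op="SUM")  # the generated __init__ takes the underscored name
print(sync(-3))                       # 3  -- uses `fn`
sync.op = None                        # assigning through the property rebuilds the cache
print(sync(-3))                       # -3 -- falls back to the no-op because `op` is None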

pytorch_lightning/utilities/apply_func.py

Lines changed: 18 additions & 7 deletions
@@ -16,7 +16,7 @@
 from abc import ABC
 from collections import defaultdict, OrderedDict
 from collections.abc import Mapping, Sequence
-from copy import copy
+from copy import copy, deepcopy
 from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union

@@ -119,21 +119,32 @@ def apply_to_collection(
         return elem_type(*out) if is_namedtuple else elem_type(out)

     if _is_dataclass_instance(data):
-        out_dict = {}
+        # make a deepcopy of the data,
+        # but do not deepcopy mapped fields since the computation would
+        # be wasted on values that likely get immediately overwritten
+        fields = {}
+        memo = {}
         for field in dataclasses.fields(data):
-            if field.init:
+            field_value = getattr(data, field.name)
+            fields[field.name] = (field_value, field.init)
+            memo[id(field_value)] = field_value
+        result = deepcopy(data, memo=memo)
+        # apply function to each field
+        for field_name, (field_value, field_init) in fields.items():
+            if field_init:
                 v = apply_to_collection(
-                    getattr(data, field.name),
+                    field_value,
                     dtype,
                     function,
                     *args,
                     wrong_dtype=wrong_dtype,
                     include_none=include_none,
                     **kwargs,
                 )
-                if include_none or v is not None:
-                    out_dict[field.name] = v
-        return elem_type(**out_dict)
+            if not field_init or (not include_none and v is None):  # retain old value
+                v = getattr(data, field_name)
+            setattr(result, field_name, v)
+        return result

     # data is neither of dtype, nor a collection
     return data
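Note: a usage sketch of the patched function, modelled on the tests added below (the dataclass is illustrative, not part of the library). The instance is deep-copied and its fields are overwritten with `setattr`, so the constructor and `__post_init__` are not re-invoked and class variables are left untouched:

from dataclasses import InitVar, dataclass
from typing import Any, ClassVar, Optional

import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection


@dataclass
class WithClassAndInitVar:
    class_var: ClassVar[int] = 0
    dummy: Any = None
    override: InitVar[Optional[Any]] = None

    def __post_init__(self, override: Optional[Any]) -> None:
        if override is not None:
            self.dummy = override


data = WithClassAndInitVar(dummy=torch.tensor([1.0, 2.0]))
doubled = apply_to_collection(data, torch.Tensor, lambda t: t * 2)

print(doubled.dummy)                  # tensor([2., 4.])
print(WithClassAndInitVar.class_var)  # 0 -- the class variable is not touched
print(data.dummy)                     # tensor([1., 2.]) -- the input instance is not mutated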

tests/core/test_results.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def _setup_ddp(rank, worldsize):
 def _ddp_test_fn(rank, worldsize):
     _setup_ddp(rank, worldsize)
     tensor = torch.tensor([1.0])
-    sync = _Sync(sync_ddp_if_available, _should=True, op="SUM")
+    sync = _Sync(sync_ddp_if_available, _should=True, _op="SUM")
     actual = sync(tensor)
     assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"

tests/models/test_tpu.py

Lines changed: 1 addition & 1 deletion
@@ -407,7 +407,7 @@ def test_tpu_sync_dist():
     """Test tpu spawn sync dist operation."""

     def test_sync_dist(_):
-        sync = _Sync(TPUSpawnPlugin().reduce, should=True, op=torch.distributed.ReduceOp.SUM)
+        sync = _Sync(TPUSpawnPlugin().reduce, should=True, _op=torch.distributed.ReduceOp.SUM)
         value = torch.tensor([1.0])
         value = (sync(value),)
         assert value.item() == 8

tests/utilities/test_apply_func.py

Lines changed: 115 additions & 43 deletions
@@ -14,7 +14,8 @@
 import dataclasses
 import numbers
 from collections import defaultdict, namedtuple, OrderedDict
-from typing import List
+from dataclasses import InitVar
+from typing import Any, ClassVar, List, Optional

 import numpy as np
 import pytest
@@ -31,6 +32,12 @@ class Feature:
         input_ids: torch.Tensor
         segment_ids: np.ndarray

+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, Feature):
+                return NotImplemented
+            else:
+                return torch.equal(self.input_ids, o.input_ids) and np.equal(self.segment_ids, o.segment_ids).all()
+
     @dataclasses.dataclass
     class ModelExample:
         example_ids: List[str]
@@ -41,6 +48,71 @@ class ModelExample:
         def __post_init__(self):
             self.some_constant = 7

+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, ModelExample):
+                return NotImplemented
+            else:
+                return (
+                    self.example_ids == o.example_ids
+                    and self.feature == o.feature
+                    and torch.equal(self.label, o.label)
+                    and self.some_constant == o.some_constant
+                )
+
+    @dataclasses.dataclass
+    class WithClassVar:
+        class_var: ClassVar[int] = 0
+        dummy: Any
+
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, WithClassVar):
+                return NotImplemented
+            elif isinstance(self.dummy, torch.Tensor):
+                return torch.equal(self.dummy, o.dummy)
+            else:
+                return self.dummy == o.dummy
+
+    @dataclasses.dataclass
+    class WithInitVar:
+        dummy: Any
+        override: InitVar[Optional[Any]] = None
+
+        def __post_init__(self, override: Optional[Any]):
+            if override is not None:
+                self.dummy = override
+
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, WithInitVar):
+                return NotImplemented
+            elif isinstance(self.dummy, torch.Tensor):
+                return torch.equal(self.dummy, o.dummy)
+            else:
+                return self.dummy == o.dummy
+
+    @dataclasses.dataclass
+    class WithClassAndInitVar:
+        class_var: ClassVar[torch.Tensor] = torch.tensor(0)
+        dummy: Any
+        override: InitVar[Optional[Any]] = torch.tensor(1)
+
+        def __post_init__(self, override: Optional[Any]):
+            if override is not None:
+                self.dummy = override
+
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, WithClassAndInitVar):
+                return NotImplemented
+            elif isinstance(self.dummy, torch.Tensor):
+                return torch.equal(self.dummy, o.dummy)
+            else:
+                return self.dummy == o.dummy
+
+    model_example = ModelExample(
+        example_ids=["i-1", "i-2", "i-3"],
+        feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
+        label=torch.tensor([7.0, 8.0, 9.0]),
+    )
+
     to_reduce = {
         "a": torch.tensor([1.0]),  # Tensor
         "b": [torch.tensor([2.0])],  # list
@@ -50,13 +122,18 @@ def __post_init__(self):
         "f": "this_is_a_dummy_str",  # string
         "g": 12.0,  # number
         "h": Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),  # dataclass
-        "i": ModelExample(
-            example_ids=["i-1", "i-2", "i-3"],
-            feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
-            label=torch.tensor([7.0, 8.0, 9.0]),
-        ),  # nested dataclass
+        "i": model_example,  # nested dataclass
+        "j": WithClassVar(torch.arange(3)),  # dataclass with class variable
+        "k": WithInitVar("this_gets_overridden", torch.tensor([2.0])),  # dataclass with init-only variable
+        "l": WithClassAndInitVar(model_example, None),  # nested dataclass with class and init-only variables
     }

+    model_example_result = ModelExample(
+        example_ids=["i-1", "i-2", "i-3"],
+        feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
+        label=torch.tensor([14.0, 16.0, 18.0]),
+    )
+
     expected_result = {
         "a": torch.tensor([2.0]),
         "b": [torch.tensor([4.0])],
@@ -66,32 +143,31 @@ def __post_init__(self):
         "f": "this_is_a_dummy_str",
         "g": 24.0,
         "h": Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
-        "i": ModelExample(
-            example_ids=["i-1", "i-2", "i-3"],
-            feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
-            label=torch.tensor([14.0, 16.0, 18.0]),
-        ),
+        "i": model_example_result,
+        "j": WithClassVar(torch.arange(0, 6, 2)),
+        "k": WithInitVar(torch.tensor([4.0])),
+        "l": WithClassAndInitVar(model_example_result, None),
     }

     reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)

-    assert isinstance(reduced, dict), " Type Consistency of dict not preserved"
+    assert isinstance(reduced, dict), "Type Consistency of dict not preserved"
     assert all(x in reduced for x in to_reduce), "Not all entries of the dict were preserved"
     assert all(
         isinstance(reduced[k], type(expected_result[k])) for k in to_reduce
     ), "At least one type was not correctly preserved"

     assert isinstance(reduced["a"], torch.Tensor), "Reduction Result of a Tensor should be a Tensor"
-    assert torch.allclose(expected_result["a"], reduced["a"]), "Reduction of a tensor does not yield the expected value"
+    assert torch.equal(expected_result["a"], reduced["a"]), "Reduction of a tensor does not yield the expected value"

     assert isinstance(reduced["b"], list), "Reduction Result of a list should be a list"
     assert all(
-        torch.allclose(x, y) for x, y in zip(reduced["b"], expected_result["b"])
+        torch.equal(x, y) for x, y in zip(reduced["b"], expected_result["b"])
     ), "At least one value of list reduction did not come out as expected"

     assert isinstance(reduced["c"], tuple), "Reduction Result of a tuple should be a tuple"
     assert all(
-        torch.allclose(x, y) for x, y in zip(reduced["c"], expected_result["c"])
+        torch.equal(x, y) for x, y in zip(reduced["c"], expected_result["c"])
     ), "At least one value of tuple reduction did not come out as expected"

     assert isinstance(reduced["d"], ntc), "Type Consistency for named tuple not given"
@@ -109,34 +185,30 @@ def __post_init__(self):
     assert isinstance(reduced["g"], numbers.Number), "Reduction of a number should result in a number"
     assert reduced["g"] == expected_result["g"], "Reduction of a number did not yield the desired result"

-    assert dataclasses.is_dataclass(reduced["h"]) and not isinstance(
-        reduced["h"], type
-    ), "Reduction of a dataclass should result in a dataclass"
-    assert torch.allclose(
-        reduced["h"].input_ids, expected_result["h"].input_ids
-    ), "Reduction of a dataclass did not yield the desired result"
-    assert np.allclose(
-        reduced["h"].segment_ids, expected_result["h"].segment_ids
-    ), "Reduction of a dataclass did not yield the desired result"
-
-    assert dataclasses.is_dataclass(reduced["i"]) and not isinstance(
-        reduced["i"], type
-    ), "Reduction of a dataclass should result in a dataclass"
-    assert dataclasses.is_dataclass(reduced["i"].feature) and not isinstance(
-        reduced["i"].feature, type
-    ), "Reduction of a nested dataclass should result in a nested dataclass"
-    assert (
-        reduced["i"].example_ids == expected_result["i"].example_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
-    assert torch.allclose(
-        reduced["i"].label, expected_result["i"].label
-    ), "Reduction of a nested dataclass did not yield the desired result"
-    assert torch.allclose(
-        reduced["i"].feature.input_ids, expected_result["i"].feature.input_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
-    assert np.allclose(
-        reduced["i"].feature.segment_ids, expected_result["i"].feature.segment_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
+    def _assert_dataclass_reduction(actual, expected, dataclass_type: str = ""):
+        assert dataclasses.is_dataclass(actual) and not isinstance(
+            actual, type
+        ), f"Reduction of a {dataclass_type} dataclass should result in a dataclass"
+        for field in dataclasses.fields(actual):
+            if dataclasses.is_dataclass(field.type):
+                _assert_dataclass_reduction(getattr(actual, field.name), getattr(expected, field.name), "nested")
+        assert actual == expected, f"Reduction of a {dataclass_type} dataclass did not yield the desired result"
+
+    _assert_dataclass_reduction(reduced["h"], expected_result["h"])
+
+    _assert_dataclass_reduction(reduced["i"], expected_result["i"])
+
+    dataclass_type = "ClassVar-containing"
+    _assert_dataclass_reduction(reduced["j"], expected_result["j"], dataclass_type)
+    assert WithClassVar.class_var == 0, f"Reduction of a {dataclass_type} dataclass should not change the class var"
+
+    _assert_dataclass_reduction(reduced["k"], expected_result["k"], "InitVar-containing")
+
+    dataclass_type = "Class-and-InitVar-containing"
+    _assert_dataclass_reduction(reduced["l"], expected_result["l"], dataclass_type)
+    assert torch.equal(
+        WithClassAndInitVar.class_var, torch.tensor(0)
+    ), f"Reduction of a {dataclass_type} dataclass should not change the class var"

     # mapping support
     reduced = apply_to_collection({"a": 1, "b": 2}, int, lambda x: str(x))
