Skip to content

Commit 7b4df7b

Browse files
ronifawaelchli
andauthored
Fix issue with no-init dataclass fields in move_to_device (#9963)
Co-authored-by: ronif <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent e5dfdf3 commit 7b4df7b

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
544544
- Fixed use of `LightningCLI` in computer_vision_fine_tuning.py example ([#9934](https://github.com/PyTorchLightning/pytorch-lightning/pull/9934))
545545

546546

547+
- Fixed issue with non-init dataclass fields in `apply_to_collection` ([#9963](https://github.com/PyTorchLightning/pytorch-lightning/issues/9963))
548+
549+
547550
## [1.4.9] - 2021-09-30
548551

549552
- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))

pytorch_lightning/utilities/apply_func.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,18 +118,19 @@ def apply_to_collection(
118118

119119
if _is_dataclass_instance(data):
120120
out_dict = {}
121-
for field in data.__dataclass_fields__:
122-
v = apply_to_collection(
123-
getattr(data, field),
124-
dtype,
125-
function,
126-
*args,
127-
wrong_dtype=wrong_dtype,
128-
include_none=include_none,
129-
**kwargs,
130-
)
131-
if include_none or v is not None:
132-
out_dict[field] = v
121+
for field in dataclasses.fields(data):
122+
if field.init:
123+
v = apply_to_collection(
124+
getattr(data, field.name),
125+
dtype,
126+
function,
127+
*args,
128+
wrong_dtype=wrong_dtype,
129+
include_none=include_none,
130+
**kwargs,
131+
)
132+
if include_none or v is not None:
133+
out_dict[field.name] = v
133134
return elem_type(**out_dict)
134135

135136
# data is neither of dtype, nor a collection

tests/utilities/test_apply_func.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ class ModelExample:
3636
example_ids: List[str]
3737
feature: Feature
3838
label: torch.Tensor
39+
some_constant: int = dataclasses.field(init=False)
40+
41+
def __post_init__(self):
42+
self.some_constant = 7
3943

4044
to_reduce = {
4145
"a": torch.tensor([1.0]), # Tensor

0 commit comments

Comments
 (0)