Skip to content

Commit 2baa1e4

Browse files
committed
Fix apply_to_collection(defaultdict) (#10316)
1 parent 72288b2 commit 2baa1e4

File tree

3 files changed

+16
-7
lines changed

3 files changed

+16
-7
lines changed

CHANGELOG.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77

8+
## [1.5.1] - 2021-MM-DD
9+
10+
### Fixed
11+
12+
- Fixed `apply_to_collection(defaultdict)` ([#10316](https://github.com/PyTorchLightning/pytorch-lightning/issues/10316))
13+
14+
815
## [1.5.0] - 2021-11-02
916

1017
### Added
@@ -132,7 +139,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
132139
- Added support for empty `gpus` list to run on CPU ([#10246](https://github.com/PyTorchLightning/pytorch-lightning/pull/10246))
133140
- Added a warning if multiple batch sizes are found from ambiguous batch ([#10247](https://github.com/PyTorchLightning/pytorch-lightning/pull/10247))
134141

135-
136142
### Changed
137143

138144
- Trainer now raises a `MisconfigurationException` when its methods are called with `ckpt_path="best"` but a checkpoint callback isn't configured ([#9841](https://github.com/PyTorchLightning/pytorch-lightning/pull/9841))
@@ -184,7 +190,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
184190
- Enabled `on_load_checkpoint` for `LightningDataModule` for all `trainer_fn` ([#10238](https://github.com/PyTorchLightning/pytorch-lightning/pull/10238))
185191
- Allowed separate config files for parameters with class type when LightningCLI is in `subclass_mode=False` ([#10286](https://github.com/PyTorchLightning/pytorch-lightning/pull/10286))
186192

187-
188193
### Deprecated
189194

190195
- Deprecated Trainer argument `terminate_on_nan` in favor of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))
@@ -220,7 +225,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
220225
- Deprecated `lr_sch_names` from `LearningRateMonitor` ([#10066](https://github.com/PyTorchLightning/pytorch-lightning/pull/10066))
221226
- Deprecated `ProgressBar` callback in favor of `TQDMProgressBar` ([#10134](https://github.com/PyTorchLightning/pytorch-lightning/pull/10134))
222227

223-
224228
### Removed
225229

226230
- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
@@ -264,7 +268,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
264268
- Removed automatic patching of `{train,val,test,predict}_dataloader()` on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))
265269
- Removed `pytorch_lightning.trainer.connectors.OptimizerConnector` ([#10120](https://github.com/PyTorchLightning/pytorch-lightning/pull/10120))
266270

267-
268271
### Fixed
269272

270273
- Fixed ImageNet evaluation in example ([#10179](https://github.com/PyTorchLightning/pytorch-lightning/pull/10179))
@@ -473,7 +476,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
473476
- Added private `prevent_trainer_and_dataloaders_deepcopy` context manager on the `LightningModule` ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472))
474477
- Added support for providing callables to the Lightning CLI instead of types ([#8400](https://github.com/PyTorchLightning/pytorch-lightning/pull/8400))
475478

476-
477479
### Changed
478480

479481
- Decoupled device parsing logic from Accelerator connector to Trainer ([#8180](https://github.com/PyTorchLightning/pytorch-lightning/pull/8180))

pytorch_lightning/utilities/apply_func.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import dataclasses
1515
import operator
1616
from abc import ABC
17-
from collections import OrderedDict
17+
from collections import defaultdict, OrderedDict
1818
from collections.abc import Mapping, Sequence
1919
from copy import copy
2020
from functools import partial
@@ -102,6 +102,8 @@ def apply_to_collection(
102102
)
103103
if include_none or v is not None:
104104
out.append((k, v))
105+
if isinstance(data, defaultdict):
106+
return elem_type(data.default_factory, OrderedDict(out))
105107
return elem_type(OrderedDict(out))
106108

107109
is_namedtuple = _is_namedtuple(data)

tests/utilities/test_apply_func.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import dataclasses
1515
import numbers
16-
from collections import namedtuple, OrderedDict
16+
from collections import defaultdict, namedtuple, OrderedDict
1717
from typing import List
1818

1919
import numpy as np
@@ -153,6 +153,11 @@ def __init__(self, initial_dict):
153153
reduced = apply_to_collection(to_reduce, int, lambda x: str(x))
154154
assert reduced == _CustomCollection({"a": "1", "b": "2", "c": "3"})
155155

156+
# defaultdict
157+
to_reduce = defaultdict(int, {"a": 1, "b": 2, "c": 3})
158+
reduced = apply_to_collection(to_reduce, int, lambda x: str(x))
159+
assert reduced == defaultdict(int, {"a": "1", "b": "2", "c": "3"})
160+
156161

157162
def test_apply_to_collection_include_none():
158163
to_reduce = [1, 2, 3.4, 5.6, 7, (8, 9.1, {10: 10})]

0 commit comments

Comments
 (0)