Commit 2a372e3

Fix module dict in base finetuning (#8170)
* Fix module dict in base finetuning
* Update CHANGELOG.md
1 parent: b978d2a

3 files changed: +13 −6 lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -315,6 +315,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a DDP info message that was never shown ([#8111](https://github.com/PyTorchLightning/pytorch-lightning/pull/8111))
 
 
+- Fixed a bug where an infinite recursion would be triggered when using the `BaseFinetuning` callback on a model that contains a `ModuleDict` ([#8170](https://github.com/PyTorchLightning/pytorch-lightning/pull/8170))
+
 ## [1.3.7] - 2021-06-22
 
 - Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
```
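For context on this entry, here is a minimal sketch of the failure it describes, using an illustrative backbone rather than anything taken from the PR: `nn.ModuleDict` is iterable, but iterating it yields string keys instead of modules, which is what drove `BaseFinetuning.flatten_modules` into unbounded recursion before this fix.

```python
import torch.nn as nn
from pytorch_lightning.callbacks import BaseFinetuning

# Illustrative backbone that stores its layers in an nn.ModuleDict.
backbone = nn.ModuleDict({
    "conv": nn.Conv2d(3, 8, kernel_size=3),
    "act": nn.ReLU(),
})

# Iterating an nn.ModuleDict yields its *string keys*, and iterating a
# one-character string yields that same string again, so the pre-patch
# flatten_modules never bottomed out. With the fix, the ModuleDict is
# replaced by its .values() before recursing.
leaves = BaseFinetuning.flatten_modules(backbone)
print([type(m).__name__ for m in leaves])  # expected: ['Conv2d', 'ReLU']
```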

pytorch_lightning/callbacks/finetuning.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -20,7 +20,7 @@
 from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
 
 import torch
-from torch.nn import Module
+from torch.nn import Module, ModuleDict
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.optim.optimizer import Optimizer
 
@@ -114,6 +114,9 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
         Returns:
             List of modules
         """
+        if isinstance(modules, ModuleDict):
+            modules = modules.values()
+
         if isinstance(modules, Iterable):
             _modules = []
             for m in modules:
```
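The reason `ModuleDict` needs special-casing here (a plain PyTorch behaviour, noted for context rather than taken from the diff): `nn.ModuleList` iterates over its modules, while `nn.ModuleDict` iterates over its keys, so `flatten_modules` has to switch to `.values()` before recursing. A quick illustrative check:

```python
import torch.nn as nn

ml = nn.ModuleList([nn.Conv2d(3, 8, 3)])
md = nn.ModuleDict({"conv": nn.Conv2d(3, 8, 3)})

print([type(x).__name__ for x in ml])          # ['Conv2d']  -> modules, safe to recurse into
print(list(md))                                # ['conv']    -> string keys, not modules
print([type(x).__name__ for x in md.values()]) # ['Conv2d']  -> what the patch iterates instead
```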

tests/callbacks/test_finetuning_callback.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -331,15 +331,17 @@ class ConvBlockParam(nn.Module):
 
         def __init__(self, in_channels, out_channels):
             super().__init__()
-            self.conv = nn.Conv2d(in_channels, out_channels, 3)
-            self.act = nn.ReLU()
+            self.module_dict = nn.ModuleDict({
+                "conv": nn.Conv2d(in_channels, out_channels, 3),
+                "act": nn.ReLU(),
+            })
             # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
             self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
             self.bn = nn.BatchNorm2d(out_channels)
 
         def forward(self, x):
-            x = self.conv(x)
-            x = self.act(x)
+            x = self.module_dict["conv"](x)
+            x = self.module_dict["act"](x)
             return self.bn(x)
 
     model = nn.Sequential(
@@ -353,7 +355,7 @@ def forward(self, x):
     assert len(BaseFinetuning.flatten_modules(model)) == 10
 
     BaseFinetuning.freeze(model.encoder, train_bn=True)
-    assert not model.encoder[0].conv.weight.requires_grad  # Validate a leaf module parameter is frozen
+    assert not model.encoder[0].module_dict["conv"].weight.requires_grad  # Validate a leaf module parameter is frozen
     assert not model.encoder[0].parent_param.requires_grad  # Validate the parent module parameter is frozen
     assert model.encoder[0].bn.weight.requires_grad
 
```
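Below is a condensed, standalone sketch of what the updated test exercises; the surrounding test fixture, the `ConvBlock` sibling class, and the encoder/decoder wiring are assumed and not shown in these hunks. Freezing a block whose leaf modules live inside an `nn.ModuleDict` now terminates, the parent-owned parameter is still frozen, and batch norm stays trainable with `train_bn=True`.

```python
import torch
import torch.nn as nn
from pytorch_lightning.callbacks import BaseFinetuning


class ConvBlockParam(nn.Module):
    """Block whose leaf modules live inside an nn.ModuleDict (mirrors the test)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.module_dict = nn.ModuleDict({
            "conv": nn.Conv2d(in_channels, out_channels, 3),
            "act": nn.ReLU(),
        })
        # trivial parameter owned by the parent (non-leaf) module itself
        self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.module_dict["conv"](x)
        x = self.module_dict["act"](x)
        return self.bn(x)


block = ConvBlockParam(3, 64)
BaseFinetuning.freeze(block, train_bn=True)
assert not block.module_dict["conv"].weight.requires_grad  # leaf inside the ModuleDict is frozen
assert not block.parent_param.requires_grad                # parent-owned parameter is frozen
assert block.bn.weight.requires_grad                       # BatchNorm stays trainable with train_bn=True
```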
