
Commit ff405eb

committed
change BaseFinetuning.__apply_mapping_to_param_groups to protected instead of private and minor refactor of finetuning_callback test
1 parent 8105b09 commit ff405eb
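
Why the rename matters: a name with two leading underscores is name-mangled by Python, so from outside the class (including subclasses and tests) the method is only reachable as `_BaseFinetuning__apply_mapping_to_param_groups`, whereas a single leading underscore is the conventional "protected" marker and leaves the attribute name untouched. A minimal sketch of the difference, using illustrative class names that are not taken from the repository:

    class Base:
        @staticmethod
        def __mangled(x):  # stored on the class as _Base__mangled
            return x + 1

        @staticmethod
        def _protected(x):  # "protected" by convention only; name is unchanged
            return x + 1


    class Child(Base):
        def call_both(self, x):
            # self.__mangled(x) would raise AttributeError here, because inside
            # Child it is rewritten to self._Child__mangled, which does not exist.
            return Base._Base__mangled(x) + self._protected(x)


    print(Child().call_both(1))  # 4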

File tree

2 files changed: +30 -28 lines

  pytorch_lightning/callbacks/finetuning.py
  tests/callbacks/test_finetuning_callback.py

pytorch_lightning/callbacks/finetuning.py

Lines changed: 4 additions & 4 deletions
@@ -98,7 +98,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
         if self._restarting:
             named_parameters = dict(pl_module.named_parameters())
             for opt_idx, optimizer in enumerate(trainer.optimizers):
-                param_groups = self.__apply_mapping_to_param_groups(
+                param_groups = self._apply_mapping_to_param_groups(
                     self._internal_optimizer_metadata[opt_idx], named_parameters
                 )
                 optimizer.param_groups = param_groups
@@ -244,7 +244,7 @@ def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module:
         self.freeze_before_training(pl_module)

     @staticmethod
-    def __apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]:
+    def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]:
         output = []
         for g in param_groups:
             # skip params to save memory
@@ -262,13 +262,13 @@ def _store(
     ) -> None:
         mapping = {p: n for n, p in pl_module.named_parameters()}
         if opt_idx not in self._internal_optimizer_metadata:
-            self._internal_optimizer_metadata[opt_idx] = self.__apply_mapping_to_param_groups(
+            self._internal_optimizer_metadata[opt_idx] = self._apply_mapping_to_param_groups(
                 current_param_groups, mapping
             )
         elif num_param_groups != len(current_param_groups):
             # save new param_groups possibly created by the users.
             self._internal_optimizer_metadata[opt_idx].extend(
-                self.__apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
+                self._apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
             )

     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
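
From the hunks above, the helper takes a list of optimizer param groups plus a mapping and returns new groups whose "params" entries have been translated through that mapping: parameter -> name when the metadata is stored in `_store`, and name -> parameter when the groups are rebuilt in `on_fit_start`. A rough sketch of that behaviour, under those assumptions (not the repository's actual implementation):

    from typing import Any, Dict, List


    def apply_mapping_to_param_groups_sketch(
        param_groups: List[Dict[str, Any]], mapping: dict
    ) -> List[Dict[str, Any]]:
        """Illustrative only: map each group's 'params' entries through `mapping`,
        keeping the remaining hyper-parameters (lr, weight_decay, ...) untouched."""
        output = []
        for group in param_groups:
            # drop the original 'params' entry ("skip params to save memory")
            new_group = {key: value for key, value in group.items() if key != "params"}
            # translate params <-> names, whichever direction `mapping` encodes
            new_group["params"] = [mapping[p] for p in group["params"]]
            output.append(new_group)
        return output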

tests/callbacks/test_finetuning_callback.py

Lines changed: 26 additions & 24 deletions
@@ -287,34 +287,36 @@ def configure_optimizers(self):
     trainer.fit(model)


-def test_complex_nested_model():
-    """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
-    directly themselves rather than exclusively their submodules containing parameters."""
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, 3)
+        self.act = nn.ReLU()
+        self.bn = nn.BatchNorm2d(out_channels)

-    class ConvBlock(nn.Module):
-        def __init__(self, in_channels, out_channels):
-            super().__init__()
-            self.conv = nn.Conv2d(in_channels, out_channels, 3)
-            self.act = nn.ReLU()
-            self.bn = nn.BatchNorm2d(out_channels)
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.act(x)
+        return self.bn(x)

-        def forward(self, x):
-            x = self.conv(x)
-            x = self.act(x)
-            return self.bn(x)

-    class ConvBlockParam(nn.Module):
-        def __init__(self, in_channels, out_channels):
-            super().__init__()
-            self.module_dict = nn.ModuleDict({"conv": nn.Conv2d(in_channels, out_channels, 3), "act": nn.ReLU()})
-            # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
-            self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
-            self.bn = nn.BatchNorm2d(out_channels)
+class ConvBlockParam(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.module_dict = nn.ModuleDict({"conv": nn.Conv2d(in_channels, out_channels, 3), "act": nn.ReLU()})
+        # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
+        self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
+        self.bn = nn.BatchNorm2d(out_channels)

-        def forward(self, x):
-            x = self.module_dict["conv"](x)
-            x = self.module_dict["act"](x)
-            return self.bn(x)
+    def forward(self, x):
+        x = self.module_dict["conv"](x)
+        x = self.module_dict["act"](x)
+        return self.bn(x)
+
+
+def test_complex_nested_model():
+    """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
+    directly themselves rather than exclusively their submodules containing parameters."""

     model = nn.Sequential(
         OrderedDict(
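
The test change only hoists `ConvBlock` and `ConvBlockParam` out of `test_complex_nested_model` to module level; the scenario it covers is unchanged: `ConvBlockParam` keeps a parameter (`parent_param`) directly on a non-leaf module, and the fine-tuning utilities must still find and freeze it. A minimal illustration of that scenario, assuming the public `BaseFinetuning.freeze` helper covers parent-level parameters as the test expects (this snippet is not part of the diff):

    import torch
    from torch import nn

    from pytorch_lightning.callbacks.finetuning import BaseFinetuning


    class ConvBlockParam(nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.module_dict = nn.ModuleDict({"conv": nn.Conv2d(in_channels, out_channels, 3), "act": nn.ReLU()})
            # parameter attached directly to this parent (non-leaf) module
            self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
            self.bn = nn.BatchNorm2d(out_channels)


    block = ConvBlockParam(3, 8)
    BaseFinetuning.freeze(block, train_bn=False)
    # both the nested conv weights and the parent-level parameter should end up frozen
    assert not block.module_dict["conv"].weight.requires_grad
    assert not block.parent_param.requires_grad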
