
Commit d4de8e2

Count number of modules in train/eval mode in ModelSummary (#20159)
1 parent 60fe36a commit d4de8e2

10 files changed: +74 −7 lines
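In short: the printed model summary (plain and rich) now ends with two extra rows counting the submodules in train vs. eval mode. A quick illustrative sketch using Lightning's demo model (output abbreviated; not taken from the commit itself):

from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.model_summary import summarize

model = BoringModel()
model.layer.eval()
print(summarize(model))
# ... existing parameter rows, followed by the new ones:
# 0         Modules in train mode
# 1         Modules in eval mode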

docs/source-pytorch/advanced/transfer_learning.rst (+1)

@@ -116,6 +116,7 @@ Here's a model that uses `Huggingface transformers <https://github.com/huggingfa
         super().__init__()
 
         self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
+        self.bert.train()
         self.W = nn.Linear(bert.config.hidden_size, 3)
         self.num_classes = 3
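The new summary rows make choices like this visible at fit start: submodules left in eval mode (for example, a frozen backbone) are counted separately from those in train mode. A minimal illustrative sketch (the module structure here is hypothetical, not from the docs):

import torch.nn as nn
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.model_summary import ModelSummary

class FineTuneModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(32, 32), nn.ReLU())  # stand-in for a pretrained encoder
        self.head = nn.Linear(32, 2)

model = FineTuneModel()
model.backbone.eval()  # a frozen backbone is counted under "eval"
print(ModelSummary(model).total_training_modes)  # {'train': 1, 'eval': 3}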

src/lightning/pytorch/CHANGELOG.md (+2)

@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The `TQDMProgressBar` now provides an option to retain prior training epoch bars ([#19578](https://github.com/Lightning-AI/pytorch-lightning/pull/19578))
 
+- Added the count of modules in train and eval mode to the printed `ModelSummary` table ([#20159](https://github.com/Lightning-AI/pytorch-lightning/pull/20159))
+
 ### Changed
 
 - Triggering KeyboardInterrupt (Ctrl+C) during `.fit()`, `.evaluate()`, `.test()` or `.predict()` now terminates all processes launched by the Trainer and exits the program ([#19976](https://github.com/Lightning-AI/pytorch-lightning/pull/19976))

src/lightning/pytorch/callbacks/model_summary.py (+11 −1)

@@ -66,9 +66,17 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
         total_parameters = model_summary.total_parameters
         trainable_parameters = model_summary.trainable_parameters
         model_size = model_summary.model_size
+        total_training_modes = model_summary.total_training_modes
 
         if trainer.is_global_zero:
-            self.summarize(summary_data, total_parameters, trainable_parameters, model_size, **self._summarize_kwargs)
+            self.summarize(
+                summary_data,
+                total_parameters,
+                trainable_parameters,
+                model_size,
+                total_training_modes,
+                **self._summarize_kwargs,
+            )
 
     def _summary(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Union[DeepSpeedSummary, Summary]:
         from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
@@ -83,12 +91,14 @@ def summarize(
         total_parameters: int,
         trainable_parameters: int,
         model_size: float,
+        total_training_modes: Dict[str, int],
         **summarize_kwargs: Any,
     ) -> None:
         summary_table = _format_summary_table(
             total_parameters,
             trainable_parameters,
             model_size,
+            total_training_modes,
             *summary_data,
         )
         log.info("\n" + summary_table)
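Note the widened `summarize` signature: any subclass overriding it must now accept `total_training_modes`. A minimal sketch modeled on the updated callback test further below (the prints are illustrative, not part of the commit):

from typing import Any, Dict, List, Tuple

from lightning.pytorch.callbacks import ModelSummary

class CustomModelSummary(ModelSummary):
    @staticmethod
    def summarize(
        summary_data: List[Tuple[str, List[str]]],
        total_parameters: int,
        trainable_parameters: int,
        model_size: float,
        total_training_modes: Dict[str, int],
        **summarize_kwargs: Any,
    ) -> None:
        # surface the new counts alongside any custom reporting
        print(f"modules in train mode: {total_training_modes['train']}")
        print(f"modules in eval mode: {total_training_modes['eval']}")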

src/lightning/pytorch/callbacks/rich_model_summary.py (+4 −1)

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
 
 from typing_extensions import override
 
@@ -71,6 +71,7 @@ def summarize(
         total_parameters: int,
         trainable_parameters: int,
         model_size: float,
+        total_training_modes: Dict[str, int],
         **summarize_kwargs: Any,
     ) -> None:
         from rich import get_console
@@ -110,5 +111,7 @@ def summarize(
         grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}")
         grid.add_row(f"[bold]Total params[/]: {parameters[2]}")
         grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}")
+        grid.add_row(f"[bold]Modules in train mode[/]: {total_training_modes['train']}")
+        grid.add_row(f"[bold]Modules in eval mode[/]: {total_training_modes['eval']}")
 
         console.print(grid)
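The rich-rendered summary gains the same two rows. Standard usage, for reference:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichModelSummary

# Replaces the default summary callback; requires the `rich` package.
trainer = Trainer(callbacks=[RichModelSummary()])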

src/lightning/pytorch/utilities/model_summary/model_summary.py (+17 −1)

@@ -187,6 +187,8 @@ class ModelSummary:
     0         Non-trainable params
     132 K     Total params
     0.530     Total estimated model params size (MB)
+    3         Modules in train mode
+    0         Modules in eval mode
     >>> ModelSummary(model, max_depth=-1)  # doctest: +NORMALIZE_WHITESPACE
       | Name  | Type        | Params | Mode  | In sizes  | Out sizes
     ----------------------------------------------------------------------
@@ -198,6 +200,8 @@ class ModelSummary:
     0         Non-trainable params
     132 K     Total params
     0.530     Total estimated model params size (MB)
+    3         Modules in train mode
+    0         Modules in eval mode
 
     """
 
@@ -252,6 +256,12 @@ def param_nums(self) -> List[int]:
     def training_modes(self) -> List[bool]:
         return [layer.training for layer in self._layer_summary.values()]
 
+    @property
+    def total_training_modes(self) -> Dict[str, int]:
+        modes = [layer.training for layer in self._model.modules()]
+        modes = modes[1:]  # exclude the root module
+        return {"train": modes.count(True), "eval": modes.count(False)}
+
     @property
     def total_parameters(self) -> int:
         return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
@@ -351,8 +361,9 @@ def __str__(self) -> str:
         total_parameters = self.total_parameters
         trainable_parameters = self.trainable_parameters
         model_size = self.model_size
+        total_training_modes = self.total_training_modes
 
-        return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays)
+        return _format_summary_table(total_parameters, trainable_parameters, model_size, total_training_modes, *arrays)
 
     def __repr__(self) -> str:
         return str(self)
@@ -372,6 +383,7 @@ def _format_summary_table(
     total_parameters: int,
     trainable_parameters: int,
     model_size: float,
+    total_training_modes: Dict[str, int],
     *cols: Tuple[str, List[str]],
 ) -> str:
     """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big
@@ -408,6 +420,10 @@ def _format_summary_table(
     summary += "Total params"
     summary += "\n" + s.format(get_formatted_model_size(model_size), 10)
     summary += "Total estimated model params size (MB)"
+    summary += "\n" + s.format(total_training_modes["train"], 10)
+    summary += "Modules in train mode"
+    summary += "\n" + s.format(total_training_modes["eval"], 10)
+    summary += "Modules in eval mode"
 
     return summary
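The new property walks `Module.modules()` and drops the first entry, so the root LightningModule itself is never counted. The same logic in plain PyTorch, as an illustrative standalone sketch:

import torch.nn as nn

def count_training_modes(model: nn.Module) -> dict:
    # modules() yields the root module first; skip it, then tally the training flags
    modes = [m.training for m in model.modules()][1:]
    return {"train": modes.count(True), "eval": modes.count(False)}

net = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
net[1].eval()  # even parameter-free modules carry a training flag
print(count_training_modes(net))  # {'train': 2, 'eval': 1}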

tests/tests_pytorch/callbacks/test_early_stopping.py (+1 −1)

@@ -58,7 +58,7 @@ def on_train_epoch_end(self, trainer, pl_module):
         self.saved_states.append(self.state_dict().copy())
 
 
-@RunIf(sklearn=True)
+@RunIf(sklearn=True, skip_windows=True)  # Flaky test on Windows for unknown reasons
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_resume_early_stopping_from_checkpoint(tmp_path):
     """Prevent regressions to bugs:

tests/tests_pytorch/callbacks/test_model_summary.py (+3)

@@ -49,6 +49,7 @@ def summarize(
         total_parameters: int,
         trainable_parameters: int,
         model_size: float,
+        total_training_modes,
         **summarize_kwargs: Any,
     ) -> None:
         assert summary_data[1][0] == "Name"
@@ -64,6 +65,8 @@ def summarize(
         assert summary_data[4][0] == "Mode"
         assert summary_data[4][1][0] == "train"
 
+        assert total_training_modes == {"train": 1, "eval": 0}
+
     model = BoringModel()
     trainer = Trainer(default_root_dir=tmp_path, callbacks=CustomModelSummary(), max_steps=1)

tests/tests_pytorch/callbacks/test_rich_model_summary.py (+7 −1)

@@ -56,7 +56,13 @@ def example_input_array(self) -> Any:
     summary = summarize(model)
     summary_data = summary._get_summary_data()
 
-    model_summary.summarize(summary_data=summary_data, total_parameters=1, trainable_parameters=1, model_size=1)
+    model_summary.summarize(
+        summary_data=summary_data,
+        total_parameters=1,
+        trainable_parameters=1,
+        model_size=1,
+        total_training_modes=summary.total_training_modes,
+    )
 
     # ensure that summary was logged + the breakdown of model parameters
     assert mock_console.call_count == 2

tests/tests_pytorch/core/test_datamodules.py (+1 −1)

@@ -218,7 +218,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
     assert dm.my_state_dict == {"my": "state_dict"}
 
 
-@RunIf(sklearn=True)
+@RunIf(sklearn=True, skip_windows=True)  # Flaky test on Windows for unknown reasons
 def test_full_loop(tmp_path):
     seed_everything(7)

tests/tests_pytorch/utilities/test_model_summary.py (+27 −1)

@@ -423,6 +423,29 @@ def forward(self, x):
     assert not model.layer2.training
 
 
+def test_total_training_modes():
+    """Test that the `total_training_modes` counts the modules in 'train' and 'eval' mode, excluding the root
+    module."""
+
+    class ModelWithoutChildren(LightningModule):
+        pass
+
+    summary = ModelSummary(ModelWithoutChildren())
+    assert summary.total_training_modes == {"train": 0, "eval": 0}
+
+    model = DeepNestedModel()
+    summary = ModelSummary(model)
+    assert summary.total_training_modes == {"train": 19, "eval": 0}
+    assert sum(summary.total_training_modes.values()) == len(list(model.modules())) - 1
+
+    model = DeepNestedModel()
+    summary = ModelSummary(model)
+    model.branch1[1][0].eval()
+    model.branch2.eval()
+    assert summary.total_training_modes == {"train": 17, "eval": 2}
+    assert sum(summary.total_training_modes.values()) == len(list(model.modules())) - 1
+
+
 def test_summary_training_mode():
     """Test that the model summary captures the training mode on all submodules."""
     model = DeepNestedModel()
@@ -436,6 +459,7 @@ def test_summary_training_mode():
         "eval",  # branch2
         "train",  # head
     ]
+    assert summary.total_training_modes == {"train": 17, "eval": 2}
 
     summary = summarize(model, max_depth=-1)
     expected_eval = {"branch1.1.0", "branch2"}
@@ -445,5 +469,7 @@ def test_summary_training_mode():
     # A model with params not belonging to a layer
     model = NonLayerParamsModel()
     model.layer.eval()
-    summary_data = OrderedDict(summarize(model)._get_summary_data())
+    summary = summarize(model)
+    summary_data = OrderedDict(summary._get_summary_data())
     assert summary_data["Mode"] == ["eval", "n/a"]
+    assert summary.total_training_modes == {"train": 0, "eval": 1}
