Commit d773407

feat: Add ModelSummary Callback (#9344)

Authored by kaushikb11
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

1 parent 4f8c3ba commit d773407

File tree

11 files changed: +269 -33 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -119,6 +119,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `remove_checkpoint` to `CheckpointIO` plugin by moving the responsibility from `ModelCheckpoint` Callback ([#9373](https://github.com/PyTorchLightning/pytorch-lightning/pull/9373))
 
 
+- Added `ModelSummary` callback ([#9344](https://github.com/PyTorchLightning/pytorch-lightning/pull/9344))
+
+
 ### Changed
 
 - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))

pytorch_lightning/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,7 @@
 from pytorch_lightning.callbacks.lambda_function import LambdaCallback
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.callbacks.model_summary import ModelSummary
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar
 from pytorch_lightning.callbacks.pruning import ModelPruning
@@ -39,6 +40,7 @@
     "LearningRateMonitor",
     "ModelCheckpoint",
     "ModelPruning",
+    "ModelSummary",
     "BasePredictionWriter",
     "ProgressBar",
     "ProgressBarBase",
pytorch_lightning/callbacks/model_summary.py (new file)

Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Model Summary
+=============
+
+Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
+
+The string representation of this summary prints a table with columns containing
+the name, type and number of parameters for each layer.
+
+"""
+import logging
+from typing import List, Optional, Union
+
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.utilities.model_summary import _format_summary_table, summarize
+
+log = logging.getLogger(__name__)
+
+
+class ModelSummary(Callback):
+    r"""
+    Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
+
+    Args:
+        max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
+            layer summary off.
+
+    Example::
+
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import ModelSummary
+        >>> trainer = Trainer(callbacks=[ModelSummary(max_depth=1)])
+    """
+
+    def __init__(self, max_depth: Optional[int] = 1):
+        self._max_depth: int = max_depth
+
+    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if trainer.is_global_zero and self._max_depth is not None and not trainer.testing:
+            model_summary = summarize(pl_module, max_depth=self._max_depth)
+
+            summary_data = model_summary._get_summary_data()
+            total_parameters = model_summary.total_parameters
+            trainable_parameters = model_summary.trainable_parameters
+            model_size = model_summary.model_size
+
+            self.summarize(summary_data, total_parameters, trainable_parameters, model_size)
+
+    @staticmethod
+    def summarize(
+        summary_data: List[List[Union[str, List[str]]]],
+        total_parameters: int,
+        trainable_parameters: int,
+        model_size: float,
+    ) -> None:
+        summary_table = _format_summary_table(total_parameters, trainable_parameters, model_size, *summary_data)
+
+        log.info("\n" + summary_table)
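For context, a minimal usage sketch of the new callback; the toy LitModel below is hypothetical, and any LightningModule works the same way:

    import torch
    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.callbacks import ModelSummary

    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.encoder = torch.nn.Sequential(torch.nn.Linear(32, 16), torch.nn.ReLU())

        def forward(self, x):
            return self.encoder(x)

    # max_depth=-1 expands all nested submodules; max_depth=0 suppresses the table
    trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])

The table is emitted via log.info at the start of the pretrain routine, on rank zero only, and is skipped while testing.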

pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 18 additions & 2 deletions

@@ -15,9 +15,9 @@
 from datetime import timedelta
 from typing import Dict, List, Optional, Union
 
-from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
+from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ModelSummary, ProgressBar, ProgressBarBase
 from pytorch_lightning.callbacks.timer import Timer
-from pytorch_lightning.utilities import rank_zero_info
+from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import rank_zero_deprecation
 
@@ -34,6 +34,7 @@ def on_trainer_init(
         process_position: int,
         default_root_dir: Optional[str],
         weights_save_path: Optional[str],
+        weights_summary: Optional[str],
         stochastic_weight_avg: bool,
         max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
     ):
@@ -58,6 +59,8 @@ def on_trainer_init(
         # responsible to stop the training when max_time is reached.
         self._configure_timer_callback(max_time)
 
+        self._configure_model_summary_callback(weights_summary)
+
         # init progress bar
         if process_position != 0:
             rank_zero_deprecation(
@@ -89,6 +92,19 @@ def _configure_checkpoint_callbacks(self, checkpoint_callback: bool) -> None:
         if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
             self.trainer.callbacks.append(ModelCheckpoint())
 
+    def _configure_model_summary_callback(self, weights_summary: Optional[str] = None) -> None:
+        if any(isinstance(cb, ModelSummary) for cb in self.trainer.callbacks):
+            return
+        if weights_summary is not None:
+            if weights_summary not in ModelSummaryMode.supported_types():
+                raise MisconfigurationException(
+                    f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())},"
+                    f" but got {weights_summary}"
+                )
+            max_depth = ModelSummaryMode.get_max_depth(weights_summary)
+            model_summary = ModelSummary(max_depth=max_depth)
+            self.trainer.callbacks.append(model_summary)
+
     def _configure_swa_callbacks(self):
         if not self.trainer._stochastic_weight_avg:
             return
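In effect, the legacy `weights_summary` string is now translated into a ModelSummary callback at trainer construction time, unless the user has already provided one. A rough sketch of the equivalence, using the same names as the diff:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelSummary

    # Trainer(weights_summary="top") now behaves like:
    Trainer(callbacks=[ModelSummary(max_depth=1)])   # "top"  -> max_depth=1

    # Trainer(weights_summary="full") now behaves like:
    Trainer(callbacks=[ModelSummary(max_depth=-1)])  # "full" -> max_depth=-1

Because `_configure_model_summary_callback` returns early when a ModelSummary is already in `trainer.callbacks`, a user-supplied instance always wins over the `weights_summary` value.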

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 11 deletions

@@ -79,7 +79,6 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_9
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
 
@@ -407,11 +406,6 @@ def __init__(
         # default .predict() loop
         self.predict_loop = PredictionLoop()
 
-        # training state
-        if weights_summary is not None and weights_summary not in ModelSummary.MODES:
-            raise MisconfigurationException(
-                f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, but got {weights_summary}"
-            )
         self.weights_summary = weights_summary
 
         # init callbacks
@@ -423,6 +417,7 @@ def __init__(
             process_position,
             default_root_dir,
             weights_save_path,
+            self.weights_summary,
             stochastic_weight_avg,
             max_time,
         )
@@ -1108,11 +1103,6 @@ def _pre_training_routine(self):
         # --------------------------
         self.call_hook("on_pretrain_routine_start")
 
-        # print model summary
-        if self.is_global_zero and self.weights_summary is not None and not self.testing:
-            max_depth = ModelSummary.MODES[self.weights_summary]
-            summarize(self.lightning_module, max_depth=max_depth)
-
         self.call_hook("on_pretrain_routine_end")
 
     def _run_train(self) -> None:
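The net effect on Trainer is behaviour-preserving: both the argument validation and the summary printing now live in the connector and the callback. A quick sketch:

    from pytorch_lightning import Trainer

    Trainer(weights_summary=None)    # no ModelSummary callback is configured
    Trainer(weights_summary="oops")  # still raises MisconfigurationException,
                                     # now from _configure_model_summary_callback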

pytorch_lightning/utilities/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -23,6 +23,7 @@
     DistributedType,
     GradClipAlgorithmType,
     LightningEnum,
+    ModelSummaryMode,
 )
 from pytorch_lightning.utilities.grads import grad_norm  # noqa: F401
 from pytorch_lightning.utilities.imports import (  # noqa: F401

pytorch_lightning/utilities/enums.py

Lines changed: 33 additions & 1 deletion

@@ -73,7 +73,7 @@ def supported_types() -> List[str]:
 
 
 class DistributedType(LightningEnum):
-    """Define type of ditributed computing.
+    """Define type of distributed computing.
 
     >>> # you can match the type with string
     >>> DistributedType.DDP == 'ddp'
@@ -147,3 +147,35 @@ class AutoRestartBatchKeys(LightningEnum):
     """Defines special dictionary keys used to track captured dataset state with multiple workers."""
 
     PL_RESTART_META = "__pl_restart_meta"
+
+
+class ModelSummaryMode(LightningEnum):
+    # TODO: remove in v1.6 (as `mode` would be deprecated for `max_depth`)
+    """Define the Model Summary mode to be used.
+
+    Can be one of
+    - `top`: only the top-level modules will be recorded (the children of the root module)
+    - `full`: summarizes all layers and their submodules in the root module
+
+    >>> # you can match the type with string
+    >>> ModelSummaryMode.TOP == 'TOP'
+    True
+    >>> # which is case invariant
+    >>> ModelSummaryMode.TOP in ('top', 'FULL')
+    True
+    """
+
+    TOP = "top"
+    FULL = "full"
+
+    @staticmethod
+    def get_max_depth(mode: str) -> int:
+        if mode == ModelSummaryMode.TOP:
+            return 1
+        if mode == ModelSummaryMode.FULL:
+            return -1
+        raise ValueError(f"`mode` can be {', '.join(list(ModelSummaryMode))}, got {mode}.")
+
+    @staticmethod
+    def supported_types() -> List[str]:
+        return [x.value for x in ModelSummaryMode]
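The enum keeps the old `top`/`full` semantics behind two helpers. A quick sketch of how they behave, following the code above:

    from pytorch_lightning.utilities import ModelSummaryMode

    ModelSummaryMode.supported_types()      # ['top', 'full']
    ModelSummaryMode.get_max_depth("top")   # 1  (children of the root module only)
    ModelSummaryMode.get_max_depth("full")  # -1 (all submodules, unlimited depth)
    ModelSummaryMode.get_max_depth("typo")  # raises ValueError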

pytorch_lightning/utilities/model_summary.py

Lines changed: 17 additions & 11 deletions

@@ -23,7 +23,7 @@
 from torch.utils.hooks import RemovableHandle
 
 import pytorch_lightning as pl
-from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_deprecation
+from pytorch_lightning.utilities import AMPType, DeviceType, ModelSummaryMode, rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.warnings import WarningCache
@@ -185,21 +185,21 @@ class ModelSummary:
     0.530     Total estimated model params size (MB)
     """
 
-    MODES = dict(top=1, full=-1)  # TODO: remove in v1.6
-
     def __init__(self, model, mode: Optional[str] = None, max_depth: Optional[int] = 1):
         self._model = model
 
         # temporary mapping from mode to max_depth
         if max_depth is None or mode is not None:
-            if mode in ModelSummary.MODES:
-                max_depth = ModelSummary.MODES[mode]
+            if mode in ModelSummaryMode.supported_types():
+                max_depth = ModelSummaryMode.get_max_depth(mode)
                 rank_zero_deprecation(
                     "Argument `mode` in `ModelSummary` is deprecated in v1.4"
                     f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behaviour."
                 )
             else:
-                raise MisconfigurationException(f"`mode` can be {', '.join(ModelSummary.MODES)}, got {mode}.")
+                raise MisconfigurationException(
+                    f"`mode` can be {', '.join(ModelSummaryMode.supported_types())}, got {mode}."
+                )
 
         if not isinstance(max_depth, int) or max_depth < -1:
             raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")
@@ -295,7 +295,7 @@ def _forward_example_input(self) -> None:
         model(input_)
         model.train(mode)  # restore mode of module
 
-    def __str__(self):
+    def _get_summary_data(self):
         """Makes a summary listing with:
 
         Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
@@ -310,6 +310,11 @@ def __str__(self):
             arrays.append(["In sizes", self.in_sizes])
             arrays.append(["Out sizes", self.out_sizes])
 
+        return arrays
+
+    def __str__(self):
+        arrays = self._get_summary_data()
+
         total_parameters = self.total_parameters
         trainable_parameters = self.trainable_parameters
         model_size = self.model_size
@@ -445,16 +450,17 @@ def summarize(
 
     # temporary mapping from mode to max_depth
     if max_depth is None:
-        if mode in ModelSummary.MODES:
-            max_depth = ModelSummary.MODES[mode]
+        if mode in ModelSummaryMode.supported_types():
+            max_depth = ModelSummaryMode.get_max_depth(mode)
             rank_zero_deprecation(
                 "Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"
                 f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behavior."
            )
             model_summary = ModelSummary(lightning_module, max_depth=max_depth)
     elif mode is not None:
-        raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
+        raise MisconfigurationException(
+            f"`mode` can be None, {', '.join(ModelSummaryMode.supported_types())}, got {mode}"
+        )
     else:
         model_summary = ModelSummary(lightning_module, max_depth=max_depth)
-    log.info("\n" + str(model_summary))
     return model_summary
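With `__str__` split into data collection and formatting, callers can render the table or consume the raw rows directly, and `summarize()` no longer logs as a side effect. A sketch, assuming `model` is any LightningModule:

    from pytorch_lightning.utilities.model_summary import summarize

    summary = summarize(model, max_depth=1)  # returns a ModelSummary without logging
    print(summary)                           # __str__ formats via _get_summary_data()
    rows = summary._get_summary_data()       # list of [header, column-values] pairs,
                                             # the same data the new callback consumes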
