
Commit 940b910

daniellepintz authored (with ananthsub, tchaton, pre-commit-ci[bot], and awaelchli)
[2/4] Add DeviceStatsMonitor callback (#9712)
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
1 parent 23e8b59 commit 940b910

File tree: 11 files changed (+228, -18 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -163,6 +163,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a warning when an unknown key is encountered in optimizer configuration, and when `OneCycleLR` is used with `"interval": "epoch"` ([#9666](https://github.com/PyTorchLightning/pytorch-lightning/pull/9666))


+- Added `DeviceStatsMonitor` callback ([#9712](https://github.com/PyTorchLightning/pytorch-lightning/pull/9712))
+
+
 - Added `enable_progress_bar` to Trainer constructor ([#9664](https://github.com/PyTorchLightning/pytorch-lightning/pull/9664))

dockers/tpu-tests/tpu_test_cases.jsonnet

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ local tputests = base.BaseTest {
       tests/profiler/test_xla_profiler.py \
       pytorch_lightning/utilities/xla_device.py \
       tests/accelerators/test_tpu_backend.py \
+      tests/callbacks/test_device_stats_monitor.py \
       tests/models/test_tpu.py
     test_exit_code=$?
     echo "\n||| END PYTEST LOGS |||\n"

docs/source/extensions/accelerators.rst

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ Currently there are accelerators for:
 - CPU
 - GPU
 - TPU
+- IPU

 Each Accelerator gets two plugins upon initialization:
 One to handle differences from the training routine and one to handle different precisions.

docs/source/extensions/callbacks.rst

Lines changed: 1 addition & 0 deletions
@@ -99,6 +99,7 @@ Lightning has a few built-in callbacks.
     BaseFinetuning
     BasePredictionWriter
     Callback
+    DeviceStatsMonitor
     EarlyStopping
     GPUStatsMonitor
     GradientAccumulationScheduler

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ ignore_errors = "True"

 [[tool.mypy.overrides]]
 module = [
+    "pytorch_lightning.callbacks.device_stats_monitor",
     "pytorch_lightning.callbacks.model_summary",
     "pytorch_lightning.callbacks.pruning",
     "pytorch_lightning.callbacks.rich_model_summary",

pytorch_lightning/accelerators/cpu.py

Lines changed: 1 addition & 1 deletion
@@ -35,5 +35,5 @@ def setup(self, trainer: "pl.Trainer") -> None:
         return super().setup(trainer)

     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
-        """Returns dummy implementation for now."""
+        """CPU device stats aren't supported yet."""
        return {}

pytorch_lightning/accelerators/ipu.py

Lines changed: 6 additions & 1 deletion
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable
+from typing import Any, Callable, Dict, Union

+import torch
 from torch.optim import Optimizer

 import pytorch_lightning as pl
@@ -37,3 +38,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
     def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
         # Optimizer step is handled by the IPU accelerator.
         lambda_closure()
+
+    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
+        """IPU device stats aren't supported yet."""
+        return {}
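
Note: the CPU and IPU accelerators only gain stub implementations here. For context, a minimal sketch (not part of this diff, and only an assumption about how a CUDA-backed accelerator could fill in the same hook on PyTorch >= 1.8) of a non-empty `get_device_stats`:

# Hypothetical sketch, not from this commit: a CUDA accelerator could return
# the allocator counters exposed by torch.cuda.memory_stats, whose keys include
# "allocated_bytes.all.freed" and "reserved_bytes.large_pool.peak" -- the same
# fields the GPU test added in this commit asserts on.
from typing import Any, Dict, Union

import torch


def get_device_stats(device: Union[str, torch.device]) -> Dict[str, Any]:
    # Returns a flat dict of str -> number, ready to be logged as metrics.
    return torch.cuda.memory_stats(device)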

pytorch_lightning/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
 from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
@@ -33,6 +34,7 @@
     "BackboneFinetuning",
     "BaseFinetuning",
     "Callback",
+    "DeviceStatsMonitor",
     "EarlyStopping",
     "GPUStatsMonitor",
     "XLAStatsMonitor",

pytorch_lightning/callbacks/device_stats_monitor.py (new file)

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Device Stats Monitor
====================

Monitors and logs device stats during training.

"""
from typing import Any, Dict, Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT


class DeviceStatsMonitor(Callback):
    r"""
    Automatically monitors and logs device stats during the training stage. ``DeviceStatsMonitor``
    is a special callback as it requires a ``logger`` to be passed as an argument to the ``Trainer``.

    Raises:
        MisconfigurationException:
            If ``Trainer`` has no logger.

    Example:
        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import DeviceStatsMonitor
        >>> device_stats = DeviceStatsMonitor()  # doctest: +SKIP
        >>> trainer = Trainer(callbacks=[device_stats])  # doctest: +SKIP
    """

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        if not trainer.logger:
            raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")

    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ) -> None:
        if not trainer.logger_connector.should_update_logs:
            return

        device_stats = trainer.accelerator.get_device_stats(pl_module.device)
        prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_start")
        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ) -> None:
        if not trainer.logger_connector.should_update_logs:
            return

        device_stats = trainer.accelerator.get_device_stats(pl_module.device)
        prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_end")
        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)


def prefix_metrics_keys(metrics_dict: Dict[str, float], prefix: str) -> Dict[str, float]:
    return {prefix + "." + k: v for k, v in metrics_dict.items()}
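
A minimal usage sketch of the new callback (not part of the commit; the CSVLogger, `gpus=1`, and the example metric names are illustrative assumptions): the callback needs a logger, polls the accelerator at batch start and batch end, and prefixes every stat key with the hook name via `prefix_metrics_keys`.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.loggers import CSVLogger

# A logger is mandatory; setup() raises MisconfigurationException without one.
trainer = Trainer(
    logger=CSVLogger(save_dir="logs"),
    callbacks=[DeviceStatsMonitor()],
    gpus=1,        # on CUDA the accelerator reports memory/allocator counters
    max_epochs=1,
)
# trainer.fit(model)  # any LightningModule
#
# Logged keys look like "on_train_batch_start.<stat>" and
# "on_train_batch_end.<stat>", e.g. "on_train_batch_end.allocated_bytes.all.freed".

The hook-name prefix keeps batch-start and batch-end samples of the same stat distinguishable in the logger output.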

tests/accelerators/test_tpu.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

tests/callbacks/test_device_stats_monitor.py (new file)

Lines changed: 130 additions & 0 deletions

@@ -0,0 +1,130 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


@RunIf(min_torch="1.8")
@RunIf(min_gpus=1)
def test_device_stats_gpu_from_torch(tmpdir):
    """Test GPU stats are logged using a logger with PyTorch >= 1.8.0."""
    model = BoringModel()
    device_stats = DeviceStatsMonitor()

    class DebugLogger(CSVLogger):
        @rank_zero_only
        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
            fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"]
            for f in fields:
                assert any(f in h for h in metrics.keys())

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=7,
        log_every_n_steps=1,
        gpus=1,
        callbacks=[device_stats],
        logger=DebugLogger(tmpdir),
        checkpoint_callback=False,
        enable_progress_bar=False,
    )

    trainer.fit(model)


@RunIf(max_torch="1.7")
@RunIf(min_gpus=1)
def test_device_stats_gpu_from_nvidia(tmpdir):
    """Test GPU stats are logged using a logger with PyTorch < 1.8.0."""
    model = BoringModel()
    device_stats = DeviceStatsMonitor()

    class DebugLogger(CSVLogger):
        @rank_zero_only
        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
            fields = ["utilization.gpu", "memory.used", "memory.free", "utilization.memory"]
            for f in fields:
                assert any(f in h for h in metrics.keys())

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=7,
        log_every_n_steps=1,
        gpus=1,
        callbacks=[device_stats],
        logger=DebugLogger(tmpdir),
        checkpoint_callback=False,
        enable_progress_bar=False,
    )

    trainer.fit(model)


@RunIf(tpu=True)
def test_device_stats_monitor_tpu(tmpdir):
    """Test TPU stats are logged using a logger."""

    model = BoringModel()
    device_stats = DeviceStatsMonitor()

    class DebugLogger(CSVLogger):
        @rank_zero_only
        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
            fields = ["avg. free memory (MB)", "avg. peak memory (MB)"]
            for f in fields:
                assert any(f in h for h in metrics.keys())

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        tpu_cores=8,
        log_every_n_steps=1,
        callbacks=[device_stats],
        logger=DebugLogger(tmpdir),
        checkpoint_callback=False,
        enable_progress_bar=False,
    )

    trainer.fit(model)


def test_device_stats_monitor_no_logger(tmpdir):
    """Test DeviceStatsMonitor with no logger in Trainer."""

    model = BoringModel()
    device_stats = DeviceStatsMonitor()

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[device_stats],
        max_epochs=1,
        logger=False,
        checkpoint_callback=False,
        enable_progress_bar=False,
    )

    with pytest.raises(MisconfigurationException, match="Trainer that has no logger."):
        trainer.fit(model)
