[2/4] Add DeviceStatsMonitor callback #9712

Merged · 71 commits · Oct 13, 2021
Commits (71)
c11eb87
Add interface to accelerator to get_device_stats
daniellepintz Sep 17, 2021
cba4916
Update changelog
daniellepintz Sep 17, 2021
d4252c5
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Sep 17, 2021
fe062b2
add device stats callback
daniellepintz Sep 17, 2021
d0e1233
address comments
daniellepintz Sep 17, 2021
269f3ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2021
4d8cc75
comments
daniellepintz Sep 18, 2021
8e37419
Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l…
daniellepintz Sep 18, 2021
6d9cc2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2021
351dc5a
wip
daniellepintz Sep 18, 2021
88ada79
Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l…
daniellepintz Sep 18, 2021
018e5cd
fix gpu
daniellepintz Sep 18, 2021
310f254
Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l…
daniellepintz Sep 18, 2021
a0e4bb9
Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l…
daniellepintz Sep 18, 2021
ec8084d
fix
daniellepintz Sep 18, 2021
5abce11
update docstring
daniellepintz Sep 18, 2021
3936242
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2021
0fdd368
fix tests
daniellepintz Sep 18, 2021
32f1047
Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l…
daniellepintz Sep 18, 2021
b0c014e
Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l…
daniellepintz Sep 18, 2021
e83a362
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Sep 27, 2021
07dd196
small fix
daniellepintz Sep 27, 2021
4124f15
small fix
daniellepintz Sep 27, 2021
4b81dd9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2021
9580e43
changelog
daniellepintz Sep 27, 2021
53940c7
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Sep 27, 2021
9b7cfea
comments
daniellepintz Sep 28, 2021
2b4f8b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 28, 2021
fd73606
address comments
daniellepintz Sep 29, 2021
cf82e1e
Merge branch 'dstats_callback' of github.com:daniellepintz/pytorch-li…
daniellepintz Sep 29, 2021
dc40ed3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2021
7dcb2aa
comments
daniellepintz Sep 29, 2021
849fa02
Update pytorch_lightning/callbacks/device_stats_monitor.py
awaelchli Sep 29, 2021
6289ed5
fix ipu
daniellepintz Sep 29, 2021
e6b4b23
Merge branch 'dstats_callback' of github.com:daniellepintz/pytorch-li…
daniellepintz Sep 30, 2021
4f17fdc
fix tpu test
daniellepintz Sep 30, 2021
3ae563c
tpu fix
daniellepintz Sep 30, 2021
f92d9a7
tpu test
daniellepintz Oct 1, 2021
fa98a5b
Update tests/callbacks/test_device_stats_monitor.py
daniellepintz Oct 6, 2021
5578f0c
pl_module.device
daniellepintz Oct 8, 2021
89e4d75
Merge branch 'dstats_callback' of github.com:daniellepintz/pytorch-li…
daniellepintz Oct 8, 2021
8af0026
fix test
daniellepintz Oct 8, 2021
613a4f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2021
3bcde5d
fix test
daniellepintz Oct 8, 2021
8f2927a
Merge branch 'dstats_callback' of github.com:daniellepintz/pytorch-li…
daniellepintz Oct 8, 2021
bbb8f83
tpu debug
daniellepintz Oct 8, 2021
291e310
tpu debug
daniellepintz Oct 8, 2021
7dec72c
Fix tpu test
kaushikb11 Oct 11, 2021
9b09bba
tpu cores
daniellepintz Oct 11, 2021
db207f6
Merge branch 'dstats_callback' of github.com:daniellepintz/pytorch-li…
daniellepintz Oct 11, 2021
3b2582b
Remove acceleraors/test_tpu
kaushikb11 Oct 12, 2021
12745ca
Merge branch 'dstats_callback' of https://github.com/daniellepintz/py…
kaushikb11 Oct 12, 2021
f6b9d9b
Fix tpu test
kaushikb11 Oct 12, 2021
cc494f5
Update test
kaushikb11 Oct 12, 2021
ad8a9a5
Update jsonnet
kaushikb11 Oct 12, 2021
546cdba
Update gpu tests
kaushikb11 Oct 12, 2021
bc74ec9
Update pytorch_lightning/accelerators/ipu.py
kaushikb11 Oct 12, 2021
1f8d51c
Update pytorch_lightning/callbacks/device_stats_monitor.py
kaushikb11 Oct 12, 2021
e5e3f55
Use should_update_logs property from logger_connector
kaushikb11 Oct 12, 2021
78b9159
Merge branch 'dstats_callback' of https://github.com/daniellepintz/py…
kaushikb11 Oct 12, 2021
d2b4a1d
small updates
daniellepintz Oct 12, 2021
987e22f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 12, 2021
88cfcff
fix signature
daniellepintz Oct 12, 2021
b9c4e08
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Oct 12, 2021
ab04998
fix sig
daniellepintz Oct 12, 2021
c09b8fb
fix sig
daniellepintz Oct 12, 2021
6b8753e
add prefixes to keys
daniellepintz Oct 13, 2021
5965c07
Update pytorch_lightning/accelerators/gpu.py
kaushikb11 Oct 13, 2021
b0c0ceb
Add return type
kaushikb11 Oct 13, 2021
55f8065
mypy
daniellepintz Oct 13, 2021
1b4ac1b
Merge branch 'dstats_callback' of github.com:daniellepintz/pytorch-li…
daniellepintz Oct 13, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -163,6 +163,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning when an unknown key is encountered in optimizer configuration, and when `OneCycleLR` is used with `"interval": "epoch"` ([#9666](https://github.com/PyTorchLightning/pytorch-lightning/pull/9666))


- Added `DeviceStatsMonitor` callback ([#9712](https://github.com/PyTorchLightning/pytorch-lightning/pull/9712))


- Added `enable_progress_bar` to Trainer constructor ([#9664](https://github.com/PyTorchLightning/pytorch-lightning/pull/9664))


1 change: 1 addition & 0 deletions dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -36,6 +36,7 @@ local tputests = base.BaseTest {
tests/profiler/test_xla_profiler.py \
pytorch_lightning/utilities/xla_device.py \
tests/accelerators/test_tpu_backend.py \
tests/callbacks/test_device_stats_monitor.py \
tests/models/test_tpu.py
test_exit_code=$?
echo "\n||| END PYTEST LOGS |||\n"
1 change: 1 addition & 0 deletions docs/source/extensions/accelerators.rst
@@ -14,6 +14,7 @@ Currently there are accelerators for:
- CPU
- GPU
- TPU
- IPU

Each Accelerator gets two plugins upon initialization:
One to handle differences from the training routine and one to handle different precisions.
1 change: 1 addition & 0 deletions docs/source/extensions/callbacks.rst
@@ -99,6 +99,7 @@ Lightning has a few built-in callbacks.
BaseFinetuning
BasePredictionWriter
Callback
DeviceStatsMonitor
EarlyStopping
GPUStatsMonitor
GradientAccumulationScheduler
1 change: 1 addition & 0 deletions pyproject.toml
@@ -61,6 +61,7 @@ ignore_errors = "True"

[[tool.mypy.overrides]]
module = [
"pytorch_lightning.callbacks.device_stats_monitor",
"pytorch_lightning.callbacks.model_summary",
"pytorch_lightning.callbacks.pruning",
"pytorch_lightning.callbacks.rich_model_summary",
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/cpu.py
@@ -35,5 +35,5 @@ def setup(self, trainer: "pl.Trainer") -> None:
return super().setup(trainer)

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""Returns dummy implementation for now."""
"""CPU device stats aren't supported yet."""
return {}
7 changes: 6 additions & 1 deletion pytorch_lightning/accelerators/ipu.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable
from typing import Any, Callable, Dict, Union

import torch
from torch.optim import Optimizer

import pytorch_lightning as pl
@@ -37,3 +38,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
# Optimizer step is handled by the IPU accelerator.
lambda_closure()

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""IPU device stats aren't supported yet."""
return {}
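
For contrast with the CPU and IPU stubs above, a non-empty implementation of this interface returns a flat metrics dict. Below is a minimal sketch, assuming a CUDA device and PyTorch >= 1.8; the actual gpu.py changes are not shown in this diff.

from typing import Any, Dict, Union

import torch


def get_device_stats(device: Union[str, torch.device]) -> Dict[str, Any]:
    # torch.cuda.memory_stats returns a flat dict of allocator counters, e.g.
    # "allocated_bytes.all.freed" and "reserved_bytes.large_pool.peak", which is
    # the flat shape that DeviceStatsMonitor forwards to the logger.
    return torch.cuda.memory_stats(device)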
2 changes: 2 additions & 0 deletions pytorch_lightning/callbacks/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
@@ -33,6 +34,7 @@
"BackboneFinetuning",
"BaseFinetuning",
"Callback",
"DeviceStatsMonitor",
"EarlyStopping",
"GPUStatsMonitor",
"XLAStatsMonitor",
82 changes: 82 additions & 0 deletions pytorch_lightning/callbacks/device_stats_monitor.py
@@ -0,0 +1,82 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Device Stats Monitor
====================

Monitors and logs device stats during training.

"""
from typing import Any, Dict, Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT


class DeviceStatsMonitor(Callback):
r"""
Automatically monitors and logs device stats during the training stage. ``DeviceStatsMonitor``
is a special callback as it requires a ``logger`` to be passed as an argument to the ``Trainer``.

Raises:
MisconfigurationException:
If ``Trainer`` has no logger.

Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import DeviceStatsMonitor
>>> device_stats = DeviceStatsMonitor() # doctest: +SKIP
>>> trainer = Trainer(callbacks=[device_stats]) # doctest: +SKIP
"""

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
if not trainer.logger:
raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")

def on_train_batch_start(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
batch: Any,
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
if not trainer.logger_connector.should_update_logs:
return

device_stats = trainer.accelerator.get_device_stats(pl_module.device)
prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_start")
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

def on_train_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
if not trainer.logger_connector.should_update_logs:
return

device_stats = trainer.accelerator.get_device_stats(pl_module.device)
prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_end")
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)


def prefix_metrics_keys(metrics_dict: Dict[str, float], prefix: str) -> Dict[str, float]:
return {prefix + "." + k: v for k, v in metrics_dict.items()}
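
The prefix_metrics_keys helper only namespaces the raw accelerator stats by the hook that produced them. An illustrative call (the stats dict here is hypothetical; real keys depend on the accelerator):

stats = {"active.all.current": 4, "reserved_bytes.large_pool.peak": 2097152}
prefix_metrics_keys(stats, "on_train_batch_start")
# {"on_train_batch_start.active.all.current": 4,
#  "on_train_batch_start.reserved_bytes.large_pool.peak": 2097152}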
16 changes: 0 additions & 16 deletions tests/accelerators/test_tpu.py

This file was deleted.

130 changes: 130 additions & 0 deletions tests/callbacks/test_device_stats_monitor.py
@@ -0,0 +1,130 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


@RunIf(min_torch="1.8")
@RunIf(min_gpus=1)
def test_device_stats_gpu_from_torch(tmpdir):
"""Test GPU stats are logged using a logger with Pytorch >= 1.8.0."""
model = BoringModel()
device_stats = DeviceStatsMonitor()

class DebugLogger(CSVLogger):
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"]
for f in fields:
assert any(f in h for h in metrics.keys())

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
limit_train_batches=7,
log_every_n_steps=1,
gpus=1,
callbacks=[device_stats],
logger=DebugLogger(tmpdir),
checkpoint_callback=False,
enable_progress_bar=False,
)

trainer.fit(model)


@RunIf(max_torch="1.7")
@RunIf(min_gpus=1)
def test_device_stats_gpu_from_nvidia(tmpdir):
"""Test GPU stats are logged using a logger with Pytorch < 1.8.0."""
model = BoringModel()
device_stats = DeviceStatsMonitor()

class DebugLogger(CSVLogger):
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
fields = ["utilization.gpu", "memory.used", "memory.free", "utilization.memory"]
for f in fields:
assert any(f in h for h in metrics.keys())

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
limit_train_batches=7,
log_every_n_steps=1,
gpus=1,
callbacks=[device_stats],
logger=DebugLogger(tmpdir),
checkpoint_callback=False,
enable_progress_bar=False,
)

trainer.fit(model)


@RunIf(tpu=True)
def test_device_stats_monitor_tpu(tmpdir):
"""Test TPU stats are logged using a logger."""

model = BoringModel()
device_stats = DeviceStatsMonitor()

class DebugLogger(CSVLogger):
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
fields = ["avg. free memory (MB)", "avg. peak memory (MB)"]
for f in fields:
assert any(f in h for h in metrics.keys())

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=1,
tpu_cores=8,
log_every_n_steps=1,
callbacks=[device_stats],
logger=DebugLogger(tmpdir),
checkpoint_callback=False,
enable_progress_bar=False,
)

trainer.fit(model)


def test_device_stats_monitor_no_logger(tmpdir):
"""Test DeviceStatsMonitor with no logger in Trainer."""

model = BoringModel()
device_stats = DeviceStatsMonitor()

trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[device_stats],
max_epochs=1,
logger=False,
checkpoint_callback=False,
enable_progress_bar=False,
)

with pytest.raises(MisconfigurationException, match="Trainer that has no logger."):
trainer.fit(model)
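
Outside the test suite, a minimal run that exercises the new callback could look like the sketch below; the choice of TensorBoardLogger and a single GPU is only an assumption for illustration, since the callback merely requires that some logger is configured on the Trainer.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from tests.helpers import BoringModel  # any LightningModule works here

model = BoringModel()
trainer = Trainer(
    gpus=1,
    max_epochs=1,
    log_every_n_steps=1,
    callbacks=[DeviceStatsMonitor()],
    logger=TensorBoardLogger("logs/"),
)
trainer.fit(model)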