[2/4] Add DeviceStatsMonitor callback #9712
Merged: kaushikb11 merged 71 commits into Lightning-AI:master from daniellepintz:dstats_callback on Oct 13, 2021.
Changes from all commits (71 commits):
c11eb87 Add interface to accelerator to get_device_stats (daniellepintz)
cba4916 Update changelog (daniellepintz)
d4252c5 Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (daniellepintz)
fe062b2 add device stats callback (daniellepintz)
d0e1233 address comments (daniellepintz)
269f3ff [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
4d8cc75 comments (daniellepintz)
8e37419 Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l… (daniellepintz)
6d9cc2e [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
351dc5a wip (daniellepintz)
88ada79 Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l… (daniellepintz)
018e5cd fix gpu (daniellepintz)
310f254 Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l… (daniellepintz)
a0e4bb9 Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l… (daniellepintz)
ec8084d fix (daniellepintz)
5abce11 update docstring (daniellepintz)
3936242 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
0fdd368 fix tests (daniellepintz)
32f1047 Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l… (daniellepintz)
b0c014e Merge branch 'get_device_stats' of github.com:daniellepintz/pytorch-l… (daniellepintz)
e83a362 Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (daniellepintz)
07dd196 small fix (daniellepintz)
4124f15 small fix (daniellepintz)
4b81dd9 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
9580e43 changelog (daniellepintz)
53940c7 Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (daniellepintz)
9b7cfea comments (daniellepintz)
2b4f8b8 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
fd73606 address comments (daniellepintz)
cf82e1e Merge branch 'dstats_callback' of github.com:daniellepintz/pytorch-li… (daniellepintz)
dc40ed3 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
7dcb2aa comments (daniellepintz)
849fa02 Update pytorch_lightning/callbacks/device_stats_monitor.py (awaelchli)
6289ed5 fix ipu (daniellepintz)
e6b4b23 Merge branch 'dstats_callback' of github.com:daniellepintz/pytorch-li… (daniellepintz)
4f17fdc fix tpu test (daniellepintz)
3ae563c tpu fix (daniellepintz)
f92d9a7 tpu test (daniellepintz)
fa98a5b Update tests/callbacks/test_device_stats_monitor.py (daniellepintz)
5578f0c pl_module.device (daniellepintz)
89e4d75 Merge branch 'dstats_callback' of github.com:daniellepintz/pytorch-li… (daniellepintz)
8af0026 fix test (daniellepintz)
613a4f3 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
3bcde5d fix test (daniellepintz)
8f2927a Merge branch 'dstats_callback' of github.com:daniellepintz/pytorch-li… (daniellepintz)
bbb8f83 tpu debug (daniellepintz)
291e310 tpu debug (daniellepintz)
7dec72c Fix tpu test (kaushikb11)
9b09bba tpu cores (daniellepintz)
db207f6 Merge branch 'dstats_callback' of github.com:daniellepintz/pytorch-li… (daniellepintz)
3b2582b Remove acceleraors/test_tpu (kaushikb11)
12745ca Merge branch 'dstats_callback' of https://github.com/daniellepintz/py… (kaushikb11)
f6b9d9b Fix tpu test (kaushikb11)
cc494f5 Update test (kaushikb11)
ad8a9a5 Update jsonnet (kaushikb11)
546cdba Update gpu tests (kaushikb11)
bc74ec9 Update pytorch_lightning/accelerators/ipu.py (kaushikb11)
1f8d51c Update pytorch_lightning/callbacks/device_stats_monitor.py (kaushikb11)
e5e3f55 Use should_update_logs property from logger_connector (kaushikb11)
78b9159 Merge branch 'dstats_callback' of https://github.com/daniellepintz/py… (kaushikb11)
d2b4a1d small updates (daniellepintz)
987e22f [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
88cfcff fix signature (daniellepintz)
b9c4e08 Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (daniellepintz)
ab04998 fix sig (daniellepintz)
c09b8fb fix sig (daniellepintz)
6b8753e add prefixes to keys (daniellepintz)
5965c07 Update pytorch_lightning/accelerators/gpu.py (kaushikb11)
b0c0ceb Add return type (kaushikb11)
55f8065 mypy (daniellepintz)
1b4ac1b Merge branch 'dstats_callback' of github.com:daniellepintz/pytorch-li… (daniellepintz)
File: pytorch_lightning/callbacks/device_stats_monitor.py (new file)
@@ -0,0 +1,82 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Device Stats Monitor
====================

Monitors and logs device stats during training.

"""
from typing import Any, Dict, Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT


class DeviceStatsMonitor(Callback):
    r"""
    Automatically monitors and logs device stats during the training stage.
    ``DeviceStatsMonitor`` is a special callback: it requires a ``logger`` to be
    passed as an argument to the ``Trainer``.

    Raises:
        MisconfigurationException:
            If ``Trainer`` has no logger.

    Example:
        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import DeviceStatsMonitor
        >>> device_stats = DeviceStatsMonitor()  # doctest: +SKIP
        >>> trainer = Trainer(callbacks=[device_stats])  # doctest: +SKIP
    """

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        if not trainer.logger:
            raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")

    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ) -> None:
        # Only log on steps where the logger connector says logging is due
        # (this is what respects the Trainer's logging interval).
        if not trainer.logger_connector.should_update_logs:
            return

        device_stats = trainer.accelerator.get_device_stats(pl_module.device)
        prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_start")
        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        unused: Optional[int] = 0,
    ) -> None:
        if not trainer.logger_connector.should_update_logs:
            return

        device_stats = trainer.accelerator.get_device_stats(pl_module.device)
        prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_end")
        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)


def prefix_metrics_keys(metrics_dict: Dict[str, float], prefix: str) -> Dict[str, float]:
    return {prefix + "." + k: v for k, v in metrics_dict.items()}
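For orientation, here is a minimal usage sketch (not part of this diff): it wires the callback into a Trainer and shows what prefix_metrics_keys does to the stat names. The sample stats dict is made up for illustration; the real keys come from the accelerator's get_device_stats.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.callbacks.device_stats_monitor import prefix_metrics_keys
from pytorch_lightning.loggers import CSVLogger

# A logger is required; without one, setup() raises MisconfigurationException.
trainer = Trainer(callbacks=[DeviceStatsMonitor()], logger=CSVLogger("logs"))

# prefix_metrics_keys namespaces each stat by the hook that logged it,
# so batch-start and batch-end readings stay distinguishable in the logs.
stats = {"memory.used": 1024.0, "utilization.gpu": 57.0}  # illustrative values
assert prefix_metrics_keys(stats, "on_train_batch_start") == {
    "on_train_batch_start.memory.used": 1024.0,
    "on_train_batch_start.utilization.gpu": 57.0,
}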
This file was deleted. (Per the commit history, commit 3b2582b "Remove acceleraors/test_tpu" removed the old TPU accelerator test.)

File: tests/callbacks/test_device_stats_monitor.py (new file)
@@ -0,0 +1,130 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


@RunIf(min_torch="1.8")
@RunIf(min_gpus=1)
def test_device_stats_gpu_from_torch(tmpdir):
    """Test GPU stats are logged using a logger with PyTorch >= 1.8.0."""
    model = BoringModel()
    device_stats = DeviceStatsMonitor()

    class DebugLogger(CSVLogger):
        @rank_zero_only
        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
            fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"]
            for f in fields:
                assert any(f in h for h in metrics.keys())

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=7,
        log_every_n_steps=1,
        gpus=1,
        callbacks=[device_stats],
        logger=DebugLogger(tmpdir),
        checkpoint_callback=False,
        enable_progress_bar=False,
    )

    trainer.fit(model)


@RunIf(max_torch="1.7")
@RunIf(min_gpus=1)
def test_device_stats_gpu_from_nvidia(tmpdir):
    """Test GPU stats are logged using a logger with PyTorch < 1.8.0."""
    model = BoringModel()
    device_stats = DeviceStatsMonitor()

    class DebugLogger(CSVLogger):
        @rank_zero_only
        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
            fields = ["utilization.gpu", "memory.used", "memory.free", "utilization.memory"]
            for f in fields:
                assert any(f in h for h in metrics.keys())

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=7,
        log_every_n_steps=1,
        gpus=1,
        callbacks=[device_stats],
        logger=DebugLogger(tmpdir),
        checkpoint_callback=False,
        enable_progress_bar=False,
    )

    trainer.fit(model)


@RunIf(tpu=True)
def test_device_stats_monitor_tpu(tmpdir):
    """Test TPU stats are logged using a logger."""
    model = BoringModel()
    device_stats = DeviceStatsMonitor()

    class DebugLogger(CSVLogger):
        @rank_zero_only
        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
            fields = ["avg. free memory (MB)", "avg. peak memory (MB)"]
            for f in fields:
                assert any(f in h for h in metrics.keys())

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        tpu_cores=8,
        log_every_n_steps=1,
        callbacks=[device_stats],
        logger=DebugLogger(tmpdir),
        checkpoint_callback=False,
        enable_progress_bar=False,
    )

    trainer.fit(model)


def test_device_stats_monitor_no_logger(tmpdir):
    """Test DeviceStatsMonitor with no logger in Trainer."""
    model = BoringModel()
    device_stats = DeviceStatsMonitor()

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[device_stats],
        max_epochs=1,
        logger=False,
        checkpoint_callback=False,
        enable_progress_bar=False,
    )

    with pytest.raises(MisconfigurationException, match="Trainer that has no logger."):
        trainer.fit(model)
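For context on the accelerator side of this PR series, below is a hypothetical sketch of the contract the callback consumes: the call site is trainer.accelerator.get_device_stats(pl_module.device), and commit c11eb87 adds that interface to the accelerators. The class is illustrative only; the real implementations live in pytorch_lightning/accelerators/ (gpu.py and ipu.py were both touched here, per the commit list).

import torch

class ExampleAccelerator:
    """Illustrative stub: a device goes in, a flat dict of numeric stats comes out."""

    def get_device_stats(self, device: torch.device) -> dict:
        if device.type == "cuda":
            # torch.cuda.memory_stats (available in PyTorch >= 1.8) already
            # returns a flat dict with keys such as "allocated_bytes.all.freed",
            # which is exactly what test_device_stats_gpu_from_torch asserts on.
            return torch.cuda.memory_stats(device)
        # Other device types would gather their own stats (e.g. nvidia-smi
        # fields on older PyTorch, or TPU memory info); empty dict in this stub.
        return {}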