
Commit d46831e

Merge branch 'master' into refactor/loop-restructuring
2 parents 239b8ad + e0f2e04

44 files changed, 552 additions and 394 deletions

.azure-pipelines/gpu-benchmark.yml

Lines changed: 17 additions & 0 deletions
@@ -1,3 +1,20 @@
+# Python package
+# Create and test a Python package on multiple Python versions.
+# Add steps that analyze code, save the dist with the build record, publish to a PyPI-compatible index, and more:
+# https://docs.microsoft.com/azure/devops/pipelines/languages/python
+
+trigger:
+  tags:
+    include:
+      - '*'
+  branches:
+    include:
+      - "master"
+      - "release/*"
+      - "refs/tags/*"
+
+pr: none
+
 schedules:
 - cron: "0 0 * * *"  # At the end of every day
   displayName: Daily midnight benchmark

.github/CODEOWNERS

Lines changed: 1 addition & 6 deletions
@@ -21,10 +21,10 @@
 # Packages
 /pytorch_lightning/accelerators @williamfalcon @tchaton @SeanNaren @awaelchli @justusschock @kaushikb11
 /pytorch_lightning/callbacks @williamfalcon @tchaton @carmocca @borda @kaushikb11
-/pytorch_lightning/cluster_environments @borda @tchaton @SeanNaren @carmocca @kaushikb11
 /pytorch_lightning/core @tchaton @SeanNaren @borda @carmocca @justusschock @kaushikb11
 /pytorch_lightning/distributed @williamfalcon @tchaton @awaelchli @kaushikb11
 /pytorch_lightning/loggers @tchaton @awaelchli @borda
+/pytorch_lightning/loggers/wandb.py @borisdayma
 /pytorch_lightning/loops @tchaton @awaelchli @justusschock @carmocca
 /pytorch_lightning/overrides @tchaton @SeanNaren @borda
 /pytorch_lightning/plugins @tchaton @SeanNaren @awaelchli @justusschock

@@ -38,11 +38,6 @@
 /pytorch_lightning/trainer/connectors/logger_connector @tchaton @carmocca
 /pytorch_lightning/trainer/progress.py @tchaton @awaelchli @carmocca

-# Metrics
-/pytorch_lightning/metrics/ @SkafteNicki @ananyahjha93 @justusschock
-/tests/metrics/ @SkafteNicki @ananyahjha93 @justusschock
-/docs/source/metrics.rst @SkafteNicki @ananyahjha93 @justusschock
-
 # API
 /pytorch_lightning/callbacks/base.py @williamfalcon @awaelchli @ananthsub @carmocca
 /pytorch_lightning/core/datamodule.py @williamFalcon @awaelchli @ananthsub @carmocca

CHANGELOG.md

Lines changed: 19 additions & 2 deletions
@@ -59,6 +59,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
     * Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))
     * Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
+    * Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401))
+

 - Checkpoint saving & loading extensibility:
     * Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))

@@ -108,10 +110,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183))


-- Add a warning to deepspeed when inferring batch size ([#9221](https://github.com/PyTorchLightning/pytorch-lightning/pull/9221))
+- Added a warning to deepspeed when inferring batch size ([#9221](https://github.com/PyTorchLightning/pytorch-lightning/pull/9221))
+
+
+- Added `inference_mode` for evaluation and prediction ([#8813](https://github.com/PyTorchLightning/pytorch-lightning/pull/8813))


-- Added `inference_mode` for evaluation and prediction ([8813](https://github.com/PyTorchLightning/pytorch-lightning/pull/8813))
+- Added `remove_checkpoint` to `CheckpointIO` plugin by moving the responsibility from `ModelCheckpoint` Callback ([#9373](https://github.com/PyTorchLightning/pytorch-lightning/pull/9373))


 ### Changed

@@ -177,6 +182,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851))


+- Deprecated `LightningModule.get_progress_bar_dict` and `Trainer.progress_bar_dict` in favor of `pytorch_lightning.callbacks.progress.base.get_standard_metrics` and `ProgressBarBase.get_metrics` ([#8985](https://github.com/PyTorchLightning/pytorch-lightning/pull/8985))
+
+
 - Deprecated `prepare_data_per_node` flag on Trainer and set it as a property of `DataHooks`, accessible in the `LightningModule` and `LightningDataModule` ([#8958](https://github.com/PyTorchLightning/pytorch-lightning/pull/8958))


@@ -323,9 +331,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed incorrect main progress bar indicator when resuming training mid-epoch ([#9310](https://github.com/PyTorchLightning/pytorch-lightning/pull/9310))


+- Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364))
+
+
 - Fixed `replace_sampler` missing the batch size under specific conditions ([#9367](https://github.com/PyTorchLightning/pytorch-lightning/pull/9367))


+- Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349))
+
+
+- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386))
+
+
 ## [1.4.5] - 2021-08-31

 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))

CITATION.cff

Lines changed: 2 additions & 2 deletions
@@ -4,8 +4,8 @@ title: "PyTorch Lightning"
 abstract: "The lightweight PyTorch wrapper for high-performance AI research. Scale your models, not the boilerplate."
 date-released: 2019-03-30
 authors:
-  - family-names: "William"
-    given-names: "Falcon"
+  - family-names: "Falcon"
+    given-names: "William"
   - name: "The PyTorch Lightning team"
 version: 1.4
 doi: 10.5281/zenodo.3828935

docs/source/common/lightning_module.rst

Lines changed: 0 additions & 6 deletions
@@ -1242,12 +1242,6 @@ backward
 .. automethod:: pytorch_lightning.core.lightning.LightningModule.backward
     :noindex:

-get_progress_bar_dict
-~~~~~~~~~~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict
-    :noindex:
-
 on_before_backward
 ~~~~~~~~~~~~~~~~~~
docs/source/extensions/logging.rst

Lines changed: 3 additions & 3 deletions
@@ -245,13 +245,13 @@ Modifying the progress bar

 The progress bar by default already includes the training loss and version number of the experiment
 if you are using a logger. These defaults can be customized by overriding the
-:func:`~pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict` hook in your module.
+:func:`~pytorch_lightning.callbacks.base.ProgressBarBase.get_metrics` hook in your module.

 .. code-block:: python

-    def get_progress_bar_dict(self):
+    def get_metrics(self):
         # don't show the version number
-        items = super().get_progress_bar_dict()
+        items = super().get_metrics()
         items.pop("v_num", None)
         return items
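With the new API the override point lives on the progress bar callback rather than on the LightningModule. A minimal end-to-end sketch of that workflow, not part of this commit, and assuming the tqdm-based bar is exported as `pytorch_lightning.callbacks.ProgressBar` in this development version:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBar  # assumed export of the tqdm-based bar


class NoVersionProgressBar(ProgressBar):
    def get_metrics(self, trainer, pl_module):
        # drop the experiment version number from the bar
        items = super().get_metrics(trainer, pl_module)
        items.pop("v_num", None)
        return items


trainer = Trainer(callbacks=[NoVersionProgressBar()])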

pytorch_lightning/accelerators/accelerator.py

Lines changed: 4 additions & 33 deletions
@@ -173,15 +173,7 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dat
     def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         """The actual training step.

-        Args:
-            step_kwargs: the arguments for the models training step. Can consist of the following:
-
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): Integer displaying index of this batch
-                - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-                - hiddens(:class:`~torch.Tensor`): Passed in if
-                  :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
+        See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
         """
         with self.precision_plugin.train_step_context():
             return self.training_type_plugin.training_step(*step_kwargs.values())

@@ -192,44 +184,23 @@ def post_training_step(self) -> None:
     def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
         """The actual validation step.

-        Args:
-            step_kwargs: the arguments for the models validation step. Can consist of the following:
-
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): The index of this batch
-                - dataloader_idx (int): The index of the dataloader that produced this batch
-                  (only if multiple val dataloaders used)
+        See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details
         """
         with self.precision_plugin.val_step_context():
             return self.training_type_plugin.validation_step(*step_kwargs.values())

     def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
         """The actual test step.

-        Args:
-            step_kwargs: the arguments for the models test step. Can consist of the following:
-
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): The index of this batch.
-                - dataloader_idx (int): The index of the dataloader that produced this batch
-                  (only if multiple test dataloaders used).
+        See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details
         """
         with self.precision_plugin.test_step_context():
             return self.training_type_plugin.test_step(*step_kwargs.values())

     def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         """The actual predict step.

-        Args:
-            step_kwargs: the arguments for the models predict step. Can consist of the following:
-
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): The index of this batch.
-                - dataloader_idx (int): The index of the dataloader that produced this batch
-                  (only if multiple predict dataloaders used).
+        See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details
         """
         with self.precision_plugin.predict_step_context():
             return self.training_type_plugin.predict_step(*step_kwargs.values())
pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 8 additions & 26 deletions
@@ -486,19 +486,6 @@ def __init_triggers(
     def every_n_epochs(self) -> Optional[int]:
         return self._every_n_epochs

-    def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None:
-        if trainer.should_rank_save_checkpoint and self._fs.exists(filepath):
-            self._fs.rm(filepath, recursive=True)
-            log.debug(f"Removed checkpoint: {filepath}")
-
-    def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None:
-        # make paths
-        if trainer.should_rank_save_checkpoint:
-            self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
-
-        # delegate the saving to the trainer
-        trainer.save_checkpoint(filepath, self.save_weights_only)
-
     def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Tensor] = None) -> bool:
         if current is None:
             return False

@@ -671,10 +658,10 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
         filepath = self._format_checkpoint_name(self.CHECKPOINT_NAME_LAST, monitor_candidates)
         filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}")

-        self._save_model(trainer, filepath)
+        trainer.save_checkpoint(filepath, self.save_weights_only)

-        if self.last_model_path and self.last_model_path != filepath and trainer.should_rank_save_checkpoint:
-            self._del_model(trainer, self.last_model_path)
+        if self.last_model_path and self.last_model_path != filepath:
+            trainer.training_type_plugin.remove_checkpoint(self.last_model_path)

         self.last_model_path = filepath

@@ -696,15 +683,10 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
             return

         filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
-        self._save_model(trainer, filepath)
+        trainer.save_checkpoint(filepath, self.save_weights_only)

-        if (
-            self.save_top_k == 1
-            and self.best_model_path
-            and self.best_model_path != filepath
-            and trainer.should_rank_save_checkpoint
-        ):
-            self._del_model(trainer, self.best_model_path)
+        if self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath:
+            trainer.training_type_plugin.remove_checkpoint(self.best_model_path)

         self.best_model_path = filepath

@@ -748,10 +730,10 @@ def _update_best_and_save(
             f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
             f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
         )
-        self._save_model(trainer, filepath)
+        trainer.save_checkpoint(filepath, self.save_weights_only)

         if del_filepath is not None and filepath != del_filepath:
-            self._del_model(trainer, del_filepath)
+            trainer.training_type_plugin.remove_checkpoint(del_filepath)

     def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None:
         """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML

pytorch_lightning/callbacks/progress/base.py

Lines changed: 71 additions & 0 deletions
@@ -11,7 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Dict, Union
+
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities import rank_zero_warn


 class ProgressBarBase(Callback):

@@ -177,3 +181,70 @@ def on_predict_epoch_start(self, trainer, pl_module):

     def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         self._predict_batch_idx += 1
+
+    def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
+        r"""
+        Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.
+        Implement this to override the items displayed in the progress bar.
+
+        Here is an example of how to override the defaults:
+
+        .. code-block:: python
+
+            def get_metrics(self, trainer, model):
+                # don't show the version number
+                items = super().get_metrics(trainer, model)
+                items.pop("v_num", None)
+                return items
+
+        Return:
+            Dictionary with the items to be displayed in the progress bar.
+        """
+        standard_metrics = pl_module.get_progress_bar_dict()
+        pbar_metrics = trainer.progress_bar_metrics
+        duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
+        if duplicates:
+            rank_zero_warn(
+                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
+                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
+                " If this is undesired, change the name or override `get_metrics()` in the progress bar callback.",
+                UserWarning,
+            )
+
+        return {**standard_metrics, **pbar_metrics}
+
+
+def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
+    r"""
+    Returns several standard metrics displayed in the progress bar, including the average loss value,
+    split index of BPTT (if used) and the version of the experiment when using a logger.
+
+    .. code-block::
+
+        Epoch 1: 4%|▎ | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]
+
+    Return:
+        Dictionary with the standard metrics to be displayed in the progress bar.
+    """
+    # call .item() only once but store elements without graphs
+    running_train_loss = trainer.fit_loop.running_loss.mean()
+    avg_training_loss = None
+    if running_train_loss is not None:
+        avg_training_loss = running_train_loss.cpu().item()
+    elif pl_module.automatic_optimization:
+        avg_training_loss = float("NaN")
+
+    items_dict = {}
+    if avg_training_loss is not None:
+        items_dict["loss"] = f"{avg_training_loss:.3g}"
+
+    if pl_module.truncated_bptt_steps > 0:
+        items_dict["split_idx"] = trainer.fit_loop.split_idx
+
+    if trainer.logger is not None and trainer.logger.version is not None:
+        version = trainer.logger.version
+        # show last 4 places of long version strings
+        version = version[-4:] if isinstance(version, str) else version
+        items_dict["v_num"] = version
+
+    return items_dict
pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 10 additions & 3 deletions
@@ -46,8 +46,9 @@ def render(self, task) -> RenderableType:
 class MetricsTextColumn(ProgressColumn):
     """A column containing text."""

-    def __init__(self, trainer, stage):
+    def __init__(self, trainer, pl_module, stage):
         self._trainer = trainer
+        self._pl_module = pl_module
         self._stage = stage
         self._tasks = {}
         self._current_task_id = 0

@@ -64,7 +65,13 @@ def render(self, task) -> Text:
         if self._trainer.training and task.id != self._current_task_id:
             return self._tasks[task.id]
         _text = ""
-        for k, v in self._trainer.progress_bar_dict.items():
+        # TODO(@daniellepintz): make this code cleaner
+        progress_bar_callback = getattr(self._trainer, "progress_bar_callback", None)
+        if progress_bar_callback:
+            metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module)
+        else:
+            metrics = self._trainer.progress_bar_metrics
+        for k, v in metrics.items():
             _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
         text = Text.from_markup(_text, style=None, justify="left")
         return text

@@ -163,7 +170,7 @@ def setup(self, trainer, pl_module, stage):
             "[",
             CustomTimeColumn(),
             ProcessingSpeedColumn(),
-            MetricsTextColumn(trainer, stage),
+            MetricsTextColumn(trainer, pl_module, stage),
             "]",
             console=self.console,
             refresh_per_second=self.refresh_rate,