Move sync code from step result to lightning module [6/n] #7651

Merged
merged 4 commits on May 24, 2021
Changes from 1 commit
43 changes: 37 additions & 6 deletions pytorch_lightning/core/lightning.py
@@ -17,6 +17,7 @@
import copy
import inspect
import logging
import numbers
import os
import tempfile
import types
@@ -42,10 +43,11 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()
@@ -325,6 +327,15 @@ def log(
f"Logged key: {name} should not contain information about dataloader_idx."
)

            value = self.__sync(
                value,
                sync_fn=self.trainer.training_type_plugin.reduce,
                sync_dist=sync_dist,
                sync_dist_op=sync_dist_op,
                sync_dist_group=sync_dist_group,
                device=self.device,
            )

            self._results.log(
                name,
                value,
@@ -336,12 +347,7 @@
                tbptt_reduce_fx=tbptt_reduce_fx,
                tbptt_pad_token=tbptt_pad_token,
                enable_graph=enable_graph,
                sync_dist=sync_dist,
                sync_dist_op=sync_dist_op,
                sync_dist_group=sync_dist_group,
                sync_fn=self.trainer.training_type_plugin.reduce,
                dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
                device=self.device,
            )

    def log_dict(
Expand Down Expand Up @@ -403,6 +409,31 @@ def log_dict(
                add_dataloader_idx=add_dataloader_idx
            )

    @staticmethod
    def __sync(
Comment on lines +412 to +413

Contributor: Why not move this out of the LightningModule completely to a utilities file?

Contributor (Author): Do you think other places might use it? Do you have an idea of where it would fit better?

Contributor: I think it could be part of #7534. But this is definitely not a blocker for this refactoring.

Contributor (Author): Sure, feel free to include it when that's done.

        value: _METRIC,
        sync_fn: Optional[Callable] = None,
        sync_dist: bool = False,
        sync_dist_op: Union[Any, str] = 'mean',
        sync_dist_group: Optional[Any] = None,
        device: torch.device = None,
    ) -> _METRIC:
"""Sync across workers when using distributed training"""
if not isinstance(value, (torch.Tensor, numbers.Number)):
return value

sync_fn = sync_fn or sync_ddp_if_available
dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed()
if not sync_dist or not dist_available:
return value

# TODO: Find a way to make the reduction only once, so we don't need to clone.
if isinstance(value, torch.Tensor):
value = value.clone()
else:
value = torch.tensor(value, device=device, dtype=torch.float)
return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

    def write_prediction(
        self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt'
    ):
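As a point of reference for the review thread above (and for the idea of eventually hosting this logic in a utilities module), here is a minimal single-process sketch of the gating that `__sync` performs: non-tensor, non-numeric values pass through untouched, plain numbers are promoted to a float tensor, and the actual reduction is delegated to whatever `sync_fn` is supplied. The `fake_reduce` stub and the demo values are illustrative only and are not part of this PR; the real method additionally checks `torch.distributed`/TPU availability, which is omitted here because the example runs in a single process.

import numbers
from typing import Any, Callable, Optional

import torch


def sync_metric_sketch(
    value: Any,
    sync_fn: Optional[Callable] = None,
    sync_dist: bool = False,
    sync_dist_op: str = 'mean',
    sync_dist_group: Optional[Any] = None,
    device: Optional[torch.device] = None,
) -> Any:
    """Single-process mirror of the gating in ``LightningModule.__sync`` (illustrative only)."""
    # anything that is neither a tensor nor a plain number is returned unchanged
    if not isinstance(value, (torch.Tensor, numbers.Number)):
        return value

    # nothing to do unless the caller asked for a cross-worker reduction
    if not sync_dist or sync_fn is None:
        return value

    # clone so the reduction does not mutate the caller's tensor in place
    if isinstance(value, torch.Tensor):
        value = value.clone()
    else:
        value = torch.tensor(value, device=device, dtype=torch.float)
    return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)


def fake_reduce(value: torch.Tensor, group: Optional[Any] = None, reduce_op: str = 'mean') -> torch.Tensor:
    # stand-in for ``training_type_plugin.reduce`` / ``sync_ddp_if_available``
    return value


print(sync_metric_sketch('not a metric', sync_fn=fake_reduce, sync_dist=True))  # passes through unchanged
print(sync_metric_sketch(3, sync_fn=fake_reduce, sync_dist=True))               # tensor(3.)
print(sync_metric_sketch(torch.tensor(1.0), sync_dist=False))                   # tensor(1.), no reduction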
20 changes: 0 additions & 20 deletions pytorch_lightning/core/step_result.py
@@ -13,16 +13,13 @@
# limitations under the License.
"""Result class for easier logging and epoch-wise reduction."""

import numbers
from copy import copy
from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
from torchmetrics import Metric

from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed


class Result(Dict):

@@ -88,29 +85,12 @@ def log(
        tbptt_reduce_fx: Callable = torch.mean,
        tbptt_pad_token: int = 0,
        enable_graph: bool = False,
        sync_dist: bool = False,
        sync_dist_op: Union[Any, str] = 'mean',
        sync_dist_group: Optional[Any] = None,
        sync_fn: Callable = None,
        dataloader_idx: Optional[int] = None,
        device: torch.device = None,
    ):
        # no metrics should be logged with graphs
        if not enable_graph and isinstance(value, torch.Tensor):
            value = value.detach()

        # sync across workers when using distributed training
        sync_fn = sync_fn or sync_ddp_if_available

        if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
            is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
            # TODO: Find a way to make the reduction only once, so we don't need to clone.
            if (is_dist_initialized or tpu_distributed()) and isinstance(value, torch.Tensor):
                value = value.clone()
            else:
                value = torch.tensor(value, device=device, dtype=torch.float)
            value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

        if isinstance(value, torch.Tensor) and value.device.type == "xla":
            value = value.cpu()

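From the user's side nothing about the logging API changes with this refactor: `sync_dist=True` is still passed to `self.log`, the reduction simply happens in `LightningModule.log` before the value reaches the `Result` object instead of inside `Result.log`. A minimal, hedged usage sketch (random data, single process, so the reduction is effectively a no-op) could look like the following; `TinyModel` and the synthetic dataset are illustrative and not part of the PR.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # with this PR, sync_dist routes through LightningModule.__sync before the value
        # reaches the Result object; on a single process the reduction is a no-op
        self.log('train_loss', loss, sync_dist=True, sync_dist_op='mean')
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == '__main__':
    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    trainer = pl.Trainer(fast_dev_run=True, logger=False)
    trainer.fit(TinyModel(), DataLoader(dataset, batch_size=16))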
3 changes: 2 additions & 1 deletion pytorch_lightning/utilities/types.py
@@ -16,12 +16,13 @@
- Do not include any `_TYPE` suffix
- Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`)
"""
from numbers import Number
from typing import Any, Dict, Iterator, List, Union

import torch
from torchmetrics import Metric

_METRIC = Union[Metric, torch.Tensor, int, float]
_METRIC = Union[Metric, torch.Tensor, Number]
STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]]
EPOCH_OUTPUT = List[STEP_OUTPUT]
_EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader
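For completeness, a short demonstration of what the broadened `_METRIC` alias admits: `numbers.Number` is the root of Python's numeric tower, so `int`, `float`, `Fraction`, and `complex` all satisfy the same `isinstance` check that `LightningModule.__sync` uses, while strings do not. The snippet is purely illustrative.

from fractions import Fraction
from numbers import Number

import torch

# values accepted by the broadened alias _METRIC = Union[Metric, torch.Tensor, Number]
for candidate in (3, 0.25, Fraction(1, 2), complex(1, 0), torch.tensor(1.0)):
    print(type(candidate).__name__, isinstance(candidate, (torch.Tensor, Number)))

# strings are still rejected by the isinstance gate used before syncing
print('str', isinstance('not a metric', (torch.Tensor, Number)))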