
Commit 3a6acad

Undo test changes
1 parent b87c9ff commit 3a6acad


24 files changed: +203, -38 lines changed


pytorch_lightning/callbacks/quantization.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
 if _TORCH_GREATER_EQUAL_1_8:
     from torch.quantization import FakeQuantizeBase
 else:
-    # For torch 1.7.
+    # For torch 1.6 and 1.7.
     from torch.quantization import FakeQuantize as FakeQuantizeBase
 
 import pytorch_lightning as pl
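The aliasing above lets the rest of the quantization callback handle both APIs uniformly. As a rough sketch of how such an alias is typically consumed (the helper below is illustrative, not part of the callback):

from torch import nn

from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8

if _TORCH_GREATER_EQUAL_1_8:
    from torch.quantization import FakeQuantizeBase
else:
    # torch 1.6 and 1.7 only ship the concrete class, so alias it to the same name
    from torch.quantization import FakeQuantize as FakeQuantizeBase


def count_fake_quant_modules(model: nn.Module) -> int:
    # version-agnostic membership test thanks to the alias
    return sum(isinstance(m, FakeQuantizeBase) for m in model.modules())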

pytorch_lightning/distributed/dist.py

Lines changed: 2 additions & 3 deletions
@@ -13,8 +13,7 @@
 # limitations under the License.
 from typing import Any
 
-import torch.distributed
-
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.distributed import group as _group
 
@@ -41,6 +40,6 @@ def broadcast(self, obj: Any, group=_group.WORLD):
         if self.rank != 0:
             obj = [None] * len(obj)
 
-        torch.distributed.broadcast_object_list(obj, 0, group=group or _group.WORLD)
+        broadcast_object_list(obj, 0, group=group or _group.WORLD)
 
         return obj[0]
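The `broadcast` method above relies on the standard collective idiom for shipping an arbitrary picklable object: wrap it in a one-element list, broadcast the list in place, then unwrap it. A minimal sketch of that idiom with plain `torch.distributed` (assuming torch >= 1.8, where `broadcast_object_list` is public, and an already-initialized process group; the function name is illustrative):

from typing import Any

import torch.distributed as dist


def broadcast_from_rank_zero(obj: Any) -> Any:
    # every rank passes a list of the same length; only the contents on the source rank matter
    buffer = [obj if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(buffer, src=0)
    return buffer[0]

On older PyTorch, Lightning instead uses the backported `broadcast_object_list` from the new `pytorch_lightning/overrides/torch_distributed.py` shown below.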
pytorch_lightning/overrides/torch_distributed.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+import logging
+import pickle
+
+import torch
+
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
+
+log = logging.getLogger(__name__)
+
+if torch.distributed.is_available():
+    from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember
+
+# The code underneath is taken from PyTorch `torch/distributed/distributed_c10d.py`
+# and enable broadcasting for PyTorch 1.6 and lower.
+
+
+# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160
+def _rank_not_in_group(group):
+    """Helper that checks if the current process's rank is not in a given group."""
+    if group is None:
+        return False
+    return group == GroupMember.NON_GROUP_MEMBER
+
+
+# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164
+def _object_to_tensor(obj):
+    buffer = pickle.dumps(obj)
+    byte_storage = torch.ByteStorage.from_buffer(buffer)  # type: ignore[attr-defined]
+    byte_tensor = torch.ByteTensor(byte_storage)
+    local_size = torch.LongTensor([byte_tensor.numel()])
+    return byte_tensor, local_size
+
+
+# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py
+def _tensor_to_object(tensor, tensor_size):
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    out = pickle.loads(buf)
+    return out
+
+
+# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327
+def _broadcast_object_list(object_list, src=0, group=None):
+    if _rank_not_in_group(group):
+        return
+
+    my_rank = get_rank()
+    # Serialize object_list elements to tensors on src rank.
+    if my_rank == src:
+        tensor_list, size_list = zip(*(_object_to_tensor(obj) for obj in object_list))
+        object_sizes_tensor = torch.cat(size_list)
+    else:
+        object_sizes_tensor = torch.LongTensor(len(object_list))
+
+    group_backend = get_backend(group)
+    is_nccl_backend = group_backend == Backend.NCCL
+    current_device = torch.device("cpu")
+    if is_nccl_backend:
+        # See note about using torch.cuda.current_device() here in docstring.
+        # We cannot simply use my_rank since rank == device is not necessarily
+        # true.
+        current_device = torch.device("cuda", torch.cuda.current_device())
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
+    object_sizes_tensor = object_sizes_tensor.to(current_device)
+
+    # Broadcast object sizes
+    broadcast(object_sizes_tensor, src=src, group=group)
+
+    # Concatenate and broadcast serialized object tensors
+    if my_rank == src:
+        object_tensor = torch.cat(tensor_list)
+    else:
+        object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item())
+
+    if is_nccl_backend:
+        object_tensor = object_tensor.to(current_device)
+
+    broadcast(object_tensor, src=src, group=group)
+
+    # Deserialize objects using their stored sizes.
+    offset = 0
+    if my_rank != src:
+        for i, obj_size in enumerate(object_sizes_tensor):
+            obj_view = object_tensor[offset : offset + obj_size]
+            obj_view = obj_view.type(torch.ByteTensor)  # type: ignore[call-overload]
+            offset += obj_size
+            object_list[i] = _tensor_to_object(obj_view, obj_size)
+
+
+if not torch.distributed.is_available():
+    # avoid failures on early PyTorch versions for Windows where
+    # not all functions used in `broadcast_object_list` are available.
+    def _broadcast_noop(obj, *_, **__):
+        return obj
+
+    broadcast_object_list = _broadcast_noop
+elif _TORCH_GREATER_EQUAL_1_8:
+    from torch.distributed.distributed_c10d import broadcast_object_list
+else:
+    broadcast_object_list = _broadcast_object_list
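Whichever branch is selected, callers get the same interface: a list whose entries are filled in on the non-source ranks. A rough usage sketch with the gloo backend (the address, port, and two-process launch are assumptions for illustration):

import os

import torch.distributed as dist

from pytorch_lightning.overrides.torch_distributed import broadcast_object_list


def run(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # rank 0 holds the payload; the other ranks pass a placeholder of the same length
    payload = [{"epoch": 3, "score": 0.91}] if rank == 0 else [None]
    broadcast_object_list(payload, 0)
    print(rank, payload[0])  # every rank now sees the dict from rank 0

    dist.destroy_process_group()

Each process would call `run` with its own rank, for example via `torch.multiprocessing.spawn(run, args=(2,), nprocs=2)`.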

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 10 additions & 5 deletions
@@ -34,6 +34,7 @@
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -42,6 +43,7 @@
     _FAIRSCALE_AVAILABLE,
     _HYDRA_AVAILABLE,
     _IS_WINDOWS,
+    _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
     _TORCH_GREATER_EQUAL_1_10,
@@ -285,12 +287,15 @@ def pre_configure_ddp(self):
         # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True.
         # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible.
         self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)
-        if not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get(
-            "find_unused_parameters", False
+        # todo: PyTorch 1.7.0 DDP introduces `self.reducer._rebuild_buckets()` breaking manual_optimization
+        if (
+            _TORCH_GREATER_EQUAL_1_7
+            and not self.lightning_module.automatic_optimization
+            and not self._ddp_kwargs.get("find_unused_parameters", False)
         ):
-            # TODO: PyTorch 1.7.0 DDP introduces `self.reducer._rebuild_buckets()` breaking manual_optimization
             rank_zero_warn(
-                "Lightning `manual_optimization` needs to set `find_unused_parameters=True` to properly work with DDP."
+                "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` "
+                "to properly work with DDP."
             )
             self._ddp_kwargs["find_unused_parameters"] = True
 
@@ -393,7 +398,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         obj = [obj]
         if self.global_rank != src:
             obj = [None]
-        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
+        broadcast_object_list(obj, src, group=_group.WORLD)
         return obj[0]
 
     def pre_backward(self, closure_loss: torch.Tensor) -> None:
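The reworked condition only forces `find_unused_parameters=True` when the installed PyTorch actually exhibits the `reducer._rebuild_buckets()` behaviour, i.e. 1.7.0 and newer. The decision can be sketched as a standalone predicate (the function is illustrative, not Lightning API):

from typing import Any, Dict

from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7


def needs_forced_find_unused_parameters(automatic_optimization: bool, ddp_kwargs: Dict[str, Any]) -> bool:
    # manual optimization on torch >= 1.7 requires the flag unless the user already set it
    return (
        _TORCH_GREATER_EQUAL_1_7
        and not automatic_optimization
        and not ddp_kwargs.get("find_unused_parameters", False)
    )


# e.g. True for manual optimization with no explicit flag on torch >= 1.7, False on torch 1.6
needs_forced_find_unused_parameters(automatic_optimization=False, ddp_kwargs={})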

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 10 additions & 6 deletions
@@ -27,11 +27,12 @@
 import pytorch_lightning as pl
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
@@ -245,12 +246,15 @@ def pre_configure_ddp(self):
         # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True.
         # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible.
         self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)
-        if not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get(
-            "find_unused_parameters", False
+        # todo: PyTorch 1.7.0 DDP introduces `self.reducer._rebuild_buckets()` breaking manual_optimization
+        if (
+            _TORCH_GREATER_EQUAL_1_7
+            and not self.lightning_module.automatic_optimization
+            and not self._ddp_kwargs.get("find_unused_parameters", False)
         ):
-            # TODO: PyTorch 1.7.0 DDP introduces `self.reducer._rebuild_buckets()` breaking manual_optimization
             rank_zero_warn(
-                "Lightning `manual_optimization` needs to set `find_unused_parameters=True` to properly work with DDP."
+                "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` "
+                "to properly work with DDP."
             )
             self._ddp_kwargs["find_unused_parameters"] = True
 
@@ -327,7 +331,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         obj = [obj]
         if self.global_rank != src:
            obj = [None]
-        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
+        broadcast_object_list(obj, src, group=_group.WORLD)
         return obj[0]
 
     def model_to_device(self):

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 4 additions & 1 deletion
@@ -74,6 +74,7 @@
 from pytorch_lightning.utilities.imports import (
     _HOROVOD_AVAILABLE,
     _IPU_AVAILABLE,
+    _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TPU_AVAILABLE,
 )
@@ -189,8 +190,10 @@ def _init_deterministic(self, deterministic: bool) -> None:
         self.deterministic = deterministic
         if _TORCH_GREATER_EQUAL_1_8:
             torch.use_deterministic_algorithms(deterministic)
-        else:
+        elif _TORCH_GREATER_EQUAL_1_7:
             torch.set_deterministic(deterministic)
+        else:  # the minimum version Lightning supports is PyTorch 1.6
+            torch._set_deterministic(deterministic)
         if deterministic:
             # fixing non-deterministic part of horovod
             # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
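The three-way branch tracks how the deterministic switch was renamed across PyTorch releases: the public `torch.use_deterministic_algorithms` from 1.8, `torch.set_deterministic` in 1.7, and the private `torch._set_deterministic` in 1.6. A compact sketch of the same dispatch as a free function (the helper name is an illustration):

import torch

from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8


def set_torch_deterministic(flag: bool) -> None:
    if _TORCH_GREATER_EQUAL_1_8:
        torch.use_deterministic_algorithms(flag)  # public API since 1.8
    elif _TORCH_GREATER_EQUAL_1_7:
        torch.set_deterministic(flag)  # the 1.7 spelling, later deprecated in favour of the public API
    else:
        torch._set_deterministic(flag)  # private 1.6 fallback used by this commit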

pytorch_lightning/utilities/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@
     _OMEGACONF_AVAILABLE,
     _POPTORCH_AVAILABLE,
     _RICH_AVAILABLE,
+    _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
     _TORCH_GREATER_EQUAL_1_10,

pytorch_lightning/utilities/auto_restart.py

Lines changed: 26 additions & 9 deletions
@@ -305,6 +305,9 @@ def _wrap_generator_samplers(self) -> None:
         # access wrapped dataset attributes
         dataset_dict = self.dataset.__dict__
 
+        # create a tuple of sampler names
+        samplers_names = tuple(v.__class__.__name__ for k, v in dataset_dict.items() if isinstance(v, Sampler))
+
         # create a dictionary of generator present within the dataset attributes
         dataset_sampler_generators = {k: v for k, v in dataset_dict.items() if isinstance(v, (Generator, Iterator))}
 
@@ -315,17 +318,31 @@
             if isinstance(generator, Sampler):
                 continue
 
-            # wrap the generator into a `FastForwardSampler`
-            sampler = FastForwardSampler(generator, attr_name=generator_attr_name)
+            # used to handle a weird behaviour from PyTorch 1.6
+            # where the sampler is converted to a list_iterator
+            is_legacy = False
+
+            if isinstance(generator, Generator):
+                # Generator name have the the form `SamplerName.__iter__`
+                generator_name = generator.__qualname__.split(".")[0]
+            else:
+                # assume the retrieved iterator is coming from sampler.
+                is_legacy = True
+
+            # validate the base generator name matches a sampler name.
+            if is_legacy or any(sampler_name == generator_name for sampler_name in samplers_names):
+
+                # wrap the generator into a `FastForwardSampler`
+                sampler = FastForwardSampler(generator, attr_name=generator_attr_name)
 
-            # if `CaptureIterableDataset` was available, the sampler should reload its own state.
-            if self._state_dict is not None:
-                sampler.load_state_dict(self._state_dict[generator_attr_name])
-            # store the samplers
-            self.samplers[generator_attr_name] = sampler
+                # if `CaptureIterableDataset` was available, the sampler should reload its own state.
+                if self._state_dict is not None:
+                    sampler.load_state_dict(self._state_dict[generator_attr_name])
+                # store the samplers
+                self.samplers[generator_attr_name] = sampler
 
-            # replace generator with the generator from the `FastForwardSampler`.
-            dataset_dict[generator_attr_name] = iter(sampler)
+                # replace generator with the generator from the `FastForwardSampler`.
+                dataset_dict[generator_attr_name] = iter(sampler)
 
         self.reset_on_epoch()
 
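The `is_legacy` branch exists because, as the comment notes, PyTorch 1.6 can hand back a plain iterator (e.g. a `list_iterator`) instead of a generator, so there is no `__qualname__` to match against the collected sampler names. A small illustration of the distinction being checked (the `ToySampler` class is made up for the example and does not subclass `torch.utils.data.Sampler`):

class ToySampler:
    def __iter__(self):
        yield from range(3)


gen = iter(ToySampler())
print(type(gen).__name__)               # "generator"
print(gen.__qualname__.split(".")[0])   # "ToySampler" -> matches a sampler attribute by name

legacy = iter([0, 1, 2])                # roughly what PyTorch 1.6 can return instead
print(hasattr(legacy, "__qualname__"))  # False -> falls into the `is_legacy` path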
pytorch_lightning/utilities/cloud_io.py

Lines changed: 8 additions & 1 deletion
@@ -19,6 +19,7 @@
 import fsspec
 import torch
 from fsspec.implementations.local import AbstractFileSystem, LocalFileSystem
+from packaging.version import Version
 
 
 def load(
@@ -58,6 +59,12 @@ def atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None:
     """
 
     bytesbuffer = io.BytesIO()
-    torch.save(checkpoint, bytesbuffer)
+    # Can't use the new zipfile serialization for 1.6.0 because there's a bug in
+    # torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
+    # More details can be found here: https://github.com/pytorch/pytorch/issues/42239
+    if Version(torch.__version__).release[:3] == (1, 6, 0):
+        torch.save(checkpoint, bytesbuffer, _use_new_zipfile_serialization=False)
+    else:
+        torch.save(checkpoint, bytesbuffer)
     with fsspec.open(filepath, "wb") as f:
         f.write(bytesbuffer.getvalue())
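`Version(torch.__version__).release[:3]` compares only the numeric release tuple, so local builds such as `1.6.0+cu101` still trigger the fallback while any other release does not. A quick illustration (the version strings are examples):

from packaging.version import Version

for candidate in ("1.6.0", "1.6.0+cu101", "1.7.1"):
    print(candidate, Version(candidate).release[:3] == (1, 6, 0))
# 1.6.0 True
# 1.6.0+cu101 True
# 1.7.1 False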

pytorch_lightning/utilities/imports.py

Lines changed: 2 additions & 1 deletion
@@ -70,6 +70,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
 
 _IS_WINDOWS = platform.system() == "Windows"
 _IS_INTERACTIVE = hasattr(sys, "ps1")  # https://stackoverflow.com/a/64523765
+_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
 _TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
 _TORCH_GREATER_EQUAL_1_8_1 = _compare_version("torch", operator.ge, "1.8.1")
 _TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
@@ -111,4 +112,4 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
 
 # experimental feature within PyTorch Lightning.
 def _fault_tolerant_training() -> bool:
-    return bool(int(os.getenv("PL_FAULT_TOLERANT_TRAINING", 0)))
+    return _TORCH_GREATER_EQUAL_1_7 and int(os.getenv("PL_FAULT_TOLERANT_TRAINING", 0))
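`_TORCH_GREATER_EQUAL_1_7` follows the same pattern as the neighbouring flags: compare the installed version against a threshold once, at import time. A stripped-down sketch of the idea (an illustration only, not Lightning's `_compare_version`, which additionally handles missing packages and base versions):

import operator

import torch
from packaging.version import Version


def compare_version(installed: str, op, target: str) -> bool:
    # e.g. compare_version(torch.__version__, operator.ge, "1.7.0")
    return op(Version(installed), Version(target))


TORCH_GREATER_EQUAL_1_7 = compare_version(torch.__version__, operator.ge, "1.7.0")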

pytorch_lightning/utilities/seed.py

Lines changed: 4 additions & 2 deletions
@@ -21,7 +21,7 @@
 import numpy as np
 import torch
 
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
 from pytorch_lightning.utilities.distributed import rank_zero_only
 
 log = logging.getLogger(__name__)
@@ -113,7 +113,9 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
     np.random.seed(ss.generate_state(4))
     # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
     torch_ss, stdlib_ss = ss.spawn(2)
-    torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
+    # PyTorch 1.7 and above takes a 64-bit seed
+    dtype = np.uint64 if _TORCH_GREATER_EQUAL_1_7 else np.uint32
+    torch.manual_seed(torch_ss.generate_state(1, dtype=dtype)[0])
     # use 128 bits expressed as an integer
     stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
     random.seed(stdlib_seed)
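`SeedSequence.generate_state` produces the worker seed, and its `dtype` argument controls the seed width, which is what the new branch selects on. A small illustration of the two widths (the printed values depend on the entropy; only the bit widths matter here):

import numpy as np

ss = np.random.SeedSequence(42)
torch_ss, _ = ss.spawn(2)

seed64 = int(torch_ss.generate_state(1, dtype=np.uint64)[0])  # up to 64 bits, used on torch >= 1.7
seed32 = int(torch_ss.generate_state(1, dtype=np.uint32)[0])  # up to 32 bits, used on torch 1.6

print(seed64.bit_length() <= 64, seed32.bit_length() <= 32)  # True True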

tests/callbacks/test_quantization.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
 if _TORCH_GREATER_EQUAL_1_8:
     from torch.quantization import FakeQuantizeBase
 else:
-    # For torch 1.7.
+    # For torch 1.6 and 1.7.
     from torch.quantization import FakeQuantize as FakeQuantizeBase
 
 

tests/conftest.py

Lines changed: 4 additions & 2 deletions
@@ -22,7 +22,7 @@
 import torch.distributed
 
 from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
 from tests import _PATH_DATASETS
 
 
@@ -95,8 +95,10 @@ def reset_deterministic_algorithm():
     yield
     if _TORCH_GREATER_EQUAL_1_8:
         torch.use_deterministic_algorithms(False)
-    else:
+    elif _TORCH_GREATER_EQUAL_1_7:
         torch.set_deterministic(False)
+    else:  # the minimum version Lightning supports is PyTorch 1.6
+        torch._set_deterministic(False)
 
 
 @pytest.fixture
