
Commit d61a5ab

3/n Consolidate collective functions - Integrate with TTPs
1 parent 54310bc commit d61a5ab

24 files changed (+129, -324 lines)

pytorch_lightning/accelerators/accelerator.py

Lines changed: 3 additions & 3 deletions
@@ -338,7 +338,7 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
         return self.training_type_plugin.lightning_module_state_dict()
 
     def barrier(self, name: Optional[str] = None) -> None:
-        self.training_type_plugin.barrier(name=name)
+        self.training_type_plugin.collective.barrier(name=name)
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if
@@ -348,7 +348,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
             obj: Object to broadcast to all process, usually a tensor or collection of tensors.
            src: The source rank of which the object will be broadcast from
        """
-        return self.training_type_plugin.broadcast(obj, src)
+        return self.training_type_plugin.collective.broadcast(obj, src)
 
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
         """Function to gather a tensor from several distributed processes.
@@ -361,7 +361,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
         Return:
             A tensor of shape (world_size, batch, ...)
        """
-        return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads)
+        return self.training_type_plugin.collective.all_gather(tensor, group=group, sync_grads=sync_grads)
 
     def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
         """Wraps the dataloader if necessary.

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 1 addition & 1 deletion
@@ -206,7 +206,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
         should_stop, reason = self._evaluate_stopping_criteria(current)
 
         # stop every ddp process if any world process decides to stop
-        should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
+        should_stop = trainer.training_type_plugin.collective.reduce_boolean_decision(should_stop)
         trainer.should_stop = trainer.should_stop or should_stop
         if should_stop:
             self.stopped_epoch = trainer.current_epoch
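
For reference, the semantics this call relies on: every rank contributes a vote and all ranks receive the same result, so no process stops while another keeps training. The snippet below simulates the distributed SUM locally; it is an illustration only, not an actual multi-process run.

import torch

votes_per_rank = [True, False]               # e.g. rank 0 wants to stop, rank 1 does not
world_size = len(votes_per_rank)
summed = torch.tensor([int(v) for v in votes_per_rank]).sum()

# TorchCollective.reduce_boolean_decision compares the SUM with world_size,
# so the shared decision is True only when every rank voted True.
should_stop = bool(summed == world_size)     # False here: both ranks keep training in lockstep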

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 6 additions & 4 deletions
@@ -294,7 +294,7 @@ def on_train_batch_end(
             skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
             # in case we have time differences across ranks
             # broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
-            skip_time = trainer.training_type_plugin.broadcast(skip_time)
+            skip_time = trainer.training_type_plugin.collective.broadcast(skip_time)
 
         if skip_batch and skip_time:
             return
@@ -509,7 +509,9 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Tensor] = None) -> bool:
         should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
 
         # If using multiple devices, make sure all processes are unanimous on the decision.
-        should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save)
+        should_update_best_and_save = trainer.training_type_plugin.collective.reduce_boolean_decision(
+            should_update_best_and_save
+        )
 
         return should_update_best_and_save
 
@@ -612,7 +614,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
         else:
             ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")
 
-        ckpt_path = trainer.training_type_plugin.broadcast(ckpt_path)
+        ckpt_path = trainer.training_type_plugin.collective.broadcast(ckpt_path)
 
         self.dirpath = ckpt_path
 
@@ -748,4 +750,4 @@ def file_exists(self, filepath: Union[str, Path], trainer: "pl.Trainer") -> bool:
         """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
         state to diverge between ranks."""
         exists = self._fs.exists(filepath)
-        return trainer.training_type_plugin.broadcast(exists)
+        return trainer.training_type_plugin.collective.broadcast(exists)
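
All of these checkpoint call sites use collective.broadcast for the same reason: a value computed on rank 0 (a timestamp comparison, a resolved directory, a filesystem check) must be adopted by every rank so the callback state never diverges. A minimal sketch of the pattern, using a hypothetical helper name rather than the callback method itself:

import os


def file_exists_on_all_ranks(trainer, filepath: str) -> bool:
    # Each rank checks its own filesystem, but broadcast(obj, src=0) replaces the
    # local value with rank 0's answer on every other rank, so all ranks take the
    # same code path afterwards and no rank blocks in a later collective call.
    exists = os.path.exists(filepath)
    return trainer.training_type_plugin.collective.broadcast(exists)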

pytorch_lightning/callbacks/xla_stats_monitor.py

Lines changed: 4 additions & 4 deletions
@@ -67,7 +67,7 @@ def on_train_start(self, trainer, pl_module) -> None:
            )
 
         memory_info = xm.get_memory_info(pl_module.device)
-        total_memory = trainer.training_type_plugin.reduce(memory_info["kb_total"]) * 0.001
+        total_memory = trainer.training_type_plugin.collective.reduce(memory_info["kb_total"]) * 0.001
         rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")
 
     def on_train_epoch_start(self, trainer, pl_module) -> None:
@@ -81,9 +81,9 @@ def on_train_epoch_end(self, trainer, pl_module) -> None:
         free_memory = memory_info["kb_free"]
         peak_memory = memory_info["kb_total"] - free_memory
 
-        free_memory = trainer.training_type_plugin.reduce(free_memory) * 0.001
-        peak_memory = trainer.training_type_plugin.reduce(peak_memory) * 0.001
-        epoch_time = trainer.training_type_plugin.reduce(epoch_time)
+        free_memory = trainer.training_type_plugin.collective.reduce(free_memory) * 0.001
+        peak_memory = trainer.training_type_plugin.collective.reduce(peak_memory) * 0.001
+        epoch_time = trainer.training_type_plugin.collective.reduce(epoch_time)
 
         logs["avg. free memory (MB)"] = free_memory
         logs["avg. peak memory (MB)"] = peak_memory

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion
@@ -466,7 +466,7 @@ def log(
             dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
             batch_size=batch_size,
             sync_dist=sync_dist and distributed_available(),
-            sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
+            sync_dist_fn=self.trainer.training_type_plugin.collective.reduce or sync_ddp,
             sync_dist_group=sync_dist_group,
             metric_attribute=metric_attribute,
             rank_zero_only=rank_zero_only,
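
From a user's perspective nothing changes here: self.log(..., sync_dist=True) still reduces the value across processes, only the function it binds to now lives on the collective. A usage sketch inside a hypothetical LightningModule (compute_loss is an assumed helper, not part of this diff):

import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        # sync_dist=True asks the ResultCollection to reduce the value across
        # processes; after this commit the bound sync function is
        # trainer.training_type_plugin.collective.reduce instead of plugin.reduce.
        self.log("val_loss", loss, sync_dist=True)
        return loss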

pytorch_lightning/loops/base.py

Lines changed: 3 additions & 1 deletion
@@ -238,7 +238,9 @@ def _load_from_state_dict(
             # On reload, we need to re-attach the `Metric`s back to the `ResultCollection`.
             # The references are provided through the `metric_attributes` dictionary.
             v.load_state_dict(
-                state_dict[prefix + k], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce
+                state_dict[prefix + k],
+                metrics=metric_attributes,
+                sync_fn=self.trainer.training_type_plugin.collective.reduce,
             )
 
         if not self.trainer.is_global_zero:

pytorch_lightning/plugins/collective/torch_collective.py

Lines changed: 7 additions & 4 deletions
@@ -107,7 +107,10 @@ def mean(t: torch.Tensor) -> torch.Tensor:
         return tensor
 
     def reduce_boolean_decision(self, decision: bool) -> bool:
-        decision = torch.tensor(int(decision), device=self.device)
-        decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
-        decision = bool(decision == self.world_size)
-        return decision
+        if self.local_reduce:
+            return decision
+        else:
+            decision1 = torch.tensor(int(decision), device=self.device)
+            decision2 = self.reduce(decision1, reduce_op=ReduceOp.SUM)
+            decision = bool(decision2 == self.world_size)
+            return decision
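
Read as a whole, the method after this change behaves roughly like the following. This is a simplified sketch: the surrounding TorchCollective class and its local_reduce, device and world_size attributes and reduce method are assumed from this diff rather than shown in full.

import torch
from pytorch_lightning.utilities.distributed import ReduceOp


class TorchCollectiveSketch:
    def reduce_boolean_decision(self, decision: bool) -> bool:
        # DDP2 sets local_reduce=True: the decision is already node-local,
        # so it is returned unchanged instead of being synced across ranks.
        if self.local_reduce:
            return decision
        # Otherwise every rank contributes 0 or 1; the votes are SUM-reduced and
        # the result is True only if the sum equals world_size (i.e. unanimity).
        vote = torch.tensor(int(decision), device=self.device)
        total = self.reduce(vote, reduce_op=ReduceOp.SUM)
        return bool(total == self.world_size)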

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 12 additions & 41 deletions
@@ -31,9 +31,10 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.plugins.collective.collective_plugin import Collective
+from pytorch_lightning.plugins.collective.torch_collective import TorchCollective
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -48,13 +49,7 @@
     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import (
-    distributed_available,
-    init_ddp_connection,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -91,6 +86,7 @@ def __init__(
         num_nodes: Optional[int] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
+        collective: Optional[Collective] = None,
         sync_batchnorm: Optional[bool] = None,
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
@@ -102,6 +98,7 @@
             parallel_devices=parallel_devices,
             cluster_environment=cluster_environment,
             checkpoint_io=checkpoint_io,
+            collective=collective or TorchCollective(),
         )
         self.interactive_ddp_procs = []
         if num_nodes is not None:
@@ -116,7 +113,6 @@
             " Notice that it will be overriden by the trainer setting."
         )
         self._sync_batchnorm = sync_batchnorm or False
-        self.dist = LightningDistributed()
         self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._ddp_kwargs = kwargs
         self._task_idx = None
@@ -267,8 +263,10 @@ def setup_distributed(self):
         init_ddp_connection(self.cluster_environment, self.torch_distributed_backend)
 
         # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
+        self.collective.rank = self.global_rank
+        self.collective.device = self.root_device
+        self.collective.device_id = self.determine_ddp_device_ids()
+        self.collective.world_size = self.world_size
 
     def _check_can_spawn_children(self):
         if self.local_rank != 0:
@@ -389,17 +387,6 @@ def pre_dispatch(self):
     def post_dispatch(self, trainer: "pl.Trainer") -> None:
         self.cluster_environment.teardown()
 
-    def barrier(self, *args, **kwargs) -> None:
-        if not distributed_available():
-            return
-        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
-            torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
-        else:
-            torch.distributed.barrier()
-
-    def broadcast(self, obj: object, src: int = 0) -> object:
-        return self.dist.broadcast(obj)
-
     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         """Run before precision plugin executes backward."""
         if not self.lightning_module.automatic_optimization:
@@ -408,22 +395,6 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None:
     def model_to_device(self):
         self.model.to(self.root_device)
 
-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:
-        """Reduces a tensor from several distributed processes to one aggregated tensor.
-
-        Args:
-            tensor: the tensor to sync and reduce
-            group: the process group to gather results from. Defaults to all processes (world)
-            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
-                Can also be a string 'sum' to calculate the sum during reduction.
-
-        Return:
-            reduced value, except when the input was not a tensor the output remains is unchanged
-        """
-        if isinstance(tensor, torch.Tensor):
-            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
-        return tensor
-
     def training_step(self, *args, **kwargs) -> Optional[Any]:
         return self.model(*args, **kwargs)
 
@@ -465,15 +436,15 @@ def _share_information_to_prevent_deadlock(self):
         sync_dirs = []
         global_node_rank_zero = 0
         for _ in range(self.num_nodes):
-            sync_dirs.append(self.broadcast(self._sync_dir, global_node_rank_zero))
+            sync_dirs.append(self.collective.broadcast(self._sync_dir, global_node_rank_zero))
             global_node_rank_zero += self.world_size // self.num_nodes
 
         self._sync_dir = sync_dirs[self.node_rank]
 
     def _share_pids(self):
         """Make all DDP processes aware of all processes pids."""
-        self.barrier()
-        pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device))
+        self.collective.barrier()
+        pids = self.collective.all_gather(torch.tensor(os.getpid(), device=self.root_device))
         pids = pids.cpu().numpy().tolist()
         self._pids = pids if isinstance(pids, list) else [pids]
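
Taken together, DDPPlugin stops implementing barrier/broadcast/reduce itself: it accepts (or constructs) a TorchCollective and hands it the distributed context once the process group is up. A condensed sketch of just those pieces, with everything unrelated elided; the class name is hypothetical and only the lines shown in this diff are assumed:

from typing import Optional

from pytorch_lightning.plugins.collective.collective_plugin import Collective
from pytorch_lightning.plugins.collective.torch_collective import TorchCollective
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities.distributed import init_ddp_connection


class DDPPluginSketch(ParallelPlugin):
    def __init__(self, collective: Optional[Collective] = None, **kwargs):
        # Default to the torch.distributed-backed implementation when none is given.
        super().__init__(collective=collective or TorchCollective(), **kwargs)

    def setup_distributed(self):
        init_ddp_connection(self.cluster_environment, self.torch_distributed_backend)
        # The collective needs the rank/device/world-size context to run
        # barrier, broadcast, all_gather and reduce on the plugin's behalf.
        self.collective.rank = self.global_rank
        self.collective.device = self.root_device
        self.collective.device_id = self.determine_ddp_device_ids()
        self.collective.world_size = self.world_size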

pytorch_lightning/plugins/training_type/ddp2.py

Lines changed: 1 addition & 23 deletions
@@ -11,11 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import torch
-
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
-from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.utilities.types import _METRIC_COLLECTION
 
 
 class DDP2Plugin(DDPPlugin):
@@ -33,25 +29,7 @@ def setup(self) -> None:
         # set the task idx
         self.task_idx = self.cluster_environment.local_rank()
         # the difference to DDP is that we don't call children processes here
-
-    def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
-        """Reduces a collection of tensors from all processes. It can be applied to just a single tensor. In DDP2,
-        the reduction here is only across local devices within the node.
-
-        Args:
-            collection: The collection of tensors to sync and reduce.
-            *args: ignored for DDP2
-            **kwargs: ignored for DDP2
-
-        Return:
-            Reduced tensor values or the same value if it was not or did not contain a tensor.
-        """
-
-        def mean(t: torch.Tensor) -> torch.Tensor:
-            original_dtype = t.dtype
-            return t.float().mean().to(original_dtype)
-
-        return apply_to_collection(collection, torch.Tensor, mean)
+        self.collective.local_reduce = True
 
     @property
     def root_device(self):
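
The removed DDP2-specific reduce override is what local_reduce now stands in for: with the flag set, TorchCollective skips the cross-rank sync in reduce_boolean_decision (see torch_collective.py above), matching DDP2's node-local reduction model. A minimal sketch of the flag's effect; TorchCollective's no-argument construction and the local_reduce attribute are taken from the diffs above, and this is an illustration rather than a test from the repository:

from pytorch_lightning.plugins.collective.torch_collective import TorchCollective

collective = TorchCollective()
collective.local_reduce = True  # what DDP2Plugin.setup() now does

# With the flag set, reduce_boolean_decision returns the per-node decision unchanged
# instead of SUM-reducing it across the world, mirroring DDP2's node-local reduce.
assert collective.reduce_boolean_decision(True) is True
assert collective.reduce_boolean_decision(False) is False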
