 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _FAIRSCALE_AVAILABLE:
-    from fairscale.nn.data_parallel import FullyShardedDataParallel
+    from fairscale.nn.data_parallel.fully_sharded_data_parallel import (
+        FullyShardedDataParallel, Parameter, TrainingState)

     from pytorch_lightning.overrides.fairscale import (
         LightningFullShardedDataParallel,
         unwrap_lightning_module_full_sharded,
     )


+class LightningFullyShardedDataParallel(FullyShardedDataParallel):
+    def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
+        """Hook to call on each param after the reduce-scatter."""
+        assert torch.cuda.current_stream() == self._streams["post_backward"]
+        assert param.grad is not None
+        self.assert_state(TrainingState.BACKWARD)
+        param.grad.data = reduced_grad
+        # Cast grad to param's dtype (typically FP32). Note: we do this
+        # before the move_grads_to_cpu step so that this entire hook remains
+        # non-blocking. The downside is a bit more D2H transfer in that case.
+        if self.mixed_precision:
+            param.grad.data = param.grad.data.to(dtype=param.data.dtype)
+        # Optionally move gradients to CPU, typically used if one is running
+        # the optimizer on the CPU.
+        # FIXME: the CPU copy below is disabled for now; it must run only
+        # after the gradients have been unscaled.
+        # if self.move_grads_to_cpu:
+        #     param._cpu_grad.copy_(param.grad.data, non_blocking=True)
+        #     param.grad.data = param._cpu_grad
+        # Don't let this memory get reused until after the transfers.
+        # reduced_grad.record_stream(torch.cuda.current_stream())
+
+
 class FullShardedPlugin(DDPPlugin):

     def __init__(
         self,
         cpu_offload: bool = True,
         flatten_parameters: bool = False,
-        reshard_after_forward: bool = True,
-        move_grads_to_cpu: Optional[bool] = None,
-        fp32_reduce_scatter: Optional[bool] = None,
+        reshard_after_forward: bool = False,
+        fp32_reduce_scatter: Optional[bool] = False,
         compute_dtype: Optional[torch.dtype] = None,
         bucket_cap_mb: int = 25,
         parallel_devices: Optional[List[torch.device]] = None,
         num_nodes: int = 1,
         cluster_environment: ClusterEnvironment = None,
-        sync_batchnorm: Optional[bool] = False
+        sync_batchnorm: Optional[bool] = False,
     ):
         """

@@ -72,7 +96,7 @@ def __init__(

             cpu_offload: Offload FP32 params to CPU. Only usable in precision=16 mode (default: False).

-            move_grads_to_cpu: Moves gradient shards to CPU after reducation.
+            move_grads_to_cpu: Moves gradient shards to CPU after reduction.
                 Only disable if using CPU based optimizers (defaults to ``cpu_offload``).

             flatten_parameters: Flattens parameters into a single contiguous tensor for speed efficiency
@@ -105,26 +129,27 @@ def __init__(
             raise MisconfigurationException("Currently sync batch norm is not supported by Full Sharded Training.")
         super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm=sync_batchnorm)
         self.cpu_offload = cpu_offload
-        self.move_grads_to_cpu = move_grads_to_cpu
         self.flatten_parameters = flatten_parameters
         self.reshard_after_forward = reshard_after_forward
         self.fp32_reduce_scatter = fp32_reduce_scatter
         self.compute_dtype = compute_dtype
         self.bucket_cap_mb = bucket_cap_mb

     def configure_ddp(self):
-        precision = self.lightning_module.trainer.precision
+        trainer = self.lightning_module.trainer
+        precision = trainer.precision
         self.model = FullyShardedDataParallel(
             LightningFullShardedDataParallel(self.model),
             cpu_offload=self.cpu_offload,
-            move_grads_to_cpu=self.move_grads_to_cpu,
+            move_grads_to_cpu=self.cpu_offload,
             flatten_parameters=self.flatten_parameters,
             mixed_precision=precision == "mixed",
             reshard_after_forward=self.reshard_after_forward,
             fp32_reduce_scatter=self.fp32_reduce_scatter,
             compute_dtype=self.compute_dtype,
             bucket_cap_mb=self.bucket_cap_mb,
         )
+        trainer.accelerator.setup_optimizers(trainer)

     @property
     def lightning_module(self) -> LightningModule:
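
To show how the constructor arguments described in the docstring above are meant to be used end to end, here is a minimal usage sketch; it is not part of this diff. The `pytorch_lightning.plugins` import path, the `ToyModel` module, and the two-GPU / `precision=16` Trainer settings are assumptions for illustration only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
# Import path is an assumption; at this stage of the PR the class may only be
# importable from the plugin module itself rather than the package root.
from pytorch_lightning.plugins import FullShardedPlugin


class ToyModel(pl.LightningModule):
    """Tiny module used purely to illustrate wiring up the plugin."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    train_data = DataLoader(
        TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))),
        batch_size=8,
    )
    plugin = FullShardedPlugin(
        cpu_offload=True,             # offloads FP32 shards to CPU; needs precision=16
        flatten_parameters=False,     # leave parameters unflattened
        reshard_after_forward=False,  # keep gathered params between forward and backward
    )
    # cpu_offload requires 16-bit precision per the docstring above.
    trainer = pl.Trainer(gpus=2, precision=16, max_epochs=1, plugins=[plugin])
    trainer.fit(ToyModel(), train_data)
```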