
Commit 59dbb83
update
1 parent c36e00a commit 59dbb83

2 files changed: +42 -9 lines changed

pytorch_lightning/plugins/precision/native_amp.py
Lines changed: 8 additions & 0 deletions

@@ -61,6 +61,7 @@ def backward(
         # unscale gradient to allow analyze within `on_after_backward`
         if not should_accumulate and model.automatic_optimization:
             self.scaler.unscale_(optimizer)
+            self.move_grad_to_cpu(model.trainer.model)
 
         return closure_loss
 
@@ -88,6 +89,13 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
         self.scaler.step(optimizer)
         self.scaler.update()
 
+    def move_grad_to_cpu(self, model):
+        if hasattr(model, "cpu_offload"):
+            if model.cpu_offload:
+                for param in model.params:
+                    param._cpu_grad.copy_(param.grad.data, non_blocking=True)
+                    param.grad.data = param._cpu_grad
+
     @contextmanager
     def train_step_context(self) -> Generator[autocast, None, None]:
         """Enable autocast context"""

pytorch_lightning/plugins/training_type/full_sharded.py
Lines changed: 34 additions & 9 deletions

@@ -22,29 +22,53 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _FAIRSCALE_AVAILABLE:
-    from fairscale.nn.data_parallel import FullyShardedDataParallel
+    from fairscale.nn.data_parallel.fully_sharded_data_parallel import (
+        FullyShardedDataParallel, Parameter, TrainingState)
 
     from pytorch_lightning.overrides.fairscale import (
         LightningFullShardedDataParallel,
         unwrap_lightning_module_full_sharded,
     )
 
 
+class LightningFullyShardedDataParallel(FullyShardedDataParallel):
+    def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
+        """Hook to call on each param after the reduce-scatter."""
+        assert torch.cuda.current_stream() == self._streams["post_backward"]
+        assert param.grad is not None
+        self.assert_state(TrainingState.BACKWARD)
+        param.grad.data = reduced_grad
+        # Cast grad to param's dtype (typically FP32). Note: we do this
+        # before the move_grads_to_cpu step so that this entire hook remains
+        # non-blocking. The downside is a bit more D2H transfer in that case.
+        if self.mixed_precision:
+            param.grad.data = param.grad.data.to(dtype=param.data.dtype)
+        # Optionally move gradients to CPU, typically used if one is running
+        # the optimizer on the CPU.
+        # issues with this part
+
+        # This part needs to be done after unscaling the gradients.
+        #if self.move_grads_to_cpu:
+        #    param._cpu_grad.copy_(param.grad.data, non_blocking=True)
+        #    param.grad.data = param._cpu_grad
+        # Don't let this memory get reused until after the transfers.
+        #reduced_grad.record_stream(torch.cuda.current_stream())
+
+
 class FullShardedPlugin(DDPPlugin):
 
     def __init__(
         self,
         cpu_offload: bool = True,
         flatten_parameters: bool = False,
-        reshard_after_forward: bool = True,
-        move_grads_to_cpu: Optional[bool] = None,
-        fp32_reduce_scatter: Optional[bool] = None,
+        reshard_after_forward: bool = False,
+        fp32_reduce_scatter: Optional[bool] = False,
         compute_dtype: Optional[torch.dtype] = None,
         bucket_cap_mb: int = 25,
         parallel_devices: Optional[List[torch.device]] = None,
         num_nodes: int = 1,
         cluster_environment: ClusterEnvironment = None,
-        sync_batchnorm: Optional[bool] = False
+        sync_batchnorm: Optional[bool] = False,
     ):
         """

@@ -72,7 +96,7 @@ def __init__(
 
             cpu_offload: Offload FP32 params to CPU. Only useable in precision=16 mode (default: False).
 
-            move_grads_to_cpu: Moves gradient shards to CPU after reducation.
+            move_grads_to_cpu: Moves gradient shards to CPU after reduction.
                 Only disable if using CPU based optimizers (defaults to ``cpu_offload``).
 
             flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency

@@ -105,26 +129,27 @@ def __init__(
             raise MisconfigurationException("Currently sync batch norm is not supported by Full Sharded Training.")
         super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm=sync_batchnorm)
         self.cpu_offload = cpu_offload
-        self.move_grads_to_cpu = move_grads_to_cpu
         self.flatten_parameters = flatten_parameters
         self.reshard_after_forward = reshard_after_forward
         self.fp32_reduce_scatter = fp32_reduce_scatter
         self.compute_dtype = compute_dtype
         self.bucket_cap_mb = bucket_cap_mb
 
     def configure_ddp(self):
-        precision = self.lightning_module.trainer.precision
+        trainer = self.lightning_module.trainer
+        precision = trainer.precision
         self.model = FullyShardedDataParallel(
             LightningFullShardedDataParallel(self.model),
             cpu_offload=self.cpu_offload,
-            move_grads_to_cpu=self.move_grads_to_cpu,
+            move_grads_to_cpu=self.cpu_offload,
             flatten_parameters=self.flatten_parameters,
             mixed_precision=precision == "mixed",
             reshard_after_forward=self.reshard_after_forward,
             fp32_reduce_scatter=self.fp32_reduce_scatter,
             compute_dtype=self.compute_dtype,
             bucket_cap_mb=self.bucket_cap_mb,
         )
+        trainer.accelerator.setup_optimizers(trainer)
 
     @property
     def lightning_module(self) -> LightningModule:
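Taken together, the two files tie gradient offload to cpu_offload: the plugin now passes move_grads_to_cpu=self.cpu_offload into FullyShardedDataParallel, and the native AMP plugin moves the unscaled gradients to CPU itself. A rough usage sketch on this development branch might look like the following; the import path, the MyLightningModule placeholder, and the exact Trainer arguments are assumptions rather than a documented API:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.training_type.full_sharded import FullShardedPlugin

# Hypothetical wiring (not from the commit): pass the plugin explicitly so that
# cpu_offload (and therefore move_grads_to_cpu, which now mirrors it) is enabled,
# and use precision=16 so the NativeMixedPrecisionPlugin path above is exercised.
trainer = Trainer(
    gpus=1,
    precision=16,
    plugins=[FullShardedPlugin(cpu_offload=True)],
    max_epochs=1,
)

# MyLightningModule stands in for any LightningModule:
# trainer.fit(MyLightningModule())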
