Commit fe0d088

Fix ShardedDataParallel has no attribute require_backward_grad_sync (#6915)
Co-authored-by: Kaushik B <[email protected]>
1 parent 20ff50c · commit fe0d088

4 files changed: +52, -0 lines
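
Context for the fix: fairscale's ShardedDataParallel wrapper does not define the require_backward_grad_sync attribute that torch's native DistributedDataParallel exposes and that Lightning's DDP hooks read, so combining manual optimization with the sharded plugins raised an AttributeError. The sketch below is distilled from the test added in this commit and only illustrates the user-facing scenario; the BoringModel import path is an assumption, and ddp_sharded needs fairscale plus at least 2 GPUs.

# Minimal sketch of the previously failing scenario (assumptions noted above).
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # import path assumed for this sketch


class ManualModel(BoringModel):

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt into manual optimization

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.loss(batch, self(batch))
        # Before this commit, the backward path under ddp_sharded looked up
        # require_backward_grad_sync on the wrapper and raised AttributeError.
        self.manual_backward(loss)
        opt.step()
        return {"loss": loss}


if __name__ == "__main__":
    trainer = Trainer(accelerator="ddp_sharded", gpus=2, fast_dev_run=2)
    trainer.fit(ManualModel())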

CHANGELOG.md (+3 lines)

@@ -237,6 +237,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
 
 
+- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
+
+
 ## [1.2.7] - 2021-04-06
 
 ### Fixed

pytorch_lightning/plugins/training_type/sharded.py (+10 lines)

@@ -13,6 +13,9 @@
 # limitations under the License.
 from typing import Optional
 
+import torch
+from torch.optim import Optimizer
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import is_lightning_optimizer
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@@ -33,6 +36,7 @@ def configure_ddp(self):
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
         )
+        setattr(self._model, "require_backward_grad_sync", False)
 
     def _reinit_optimizers_with_oss(self):
         optimizers = self.lightning_module.trainer.optimizers
@@ -70,3 +74,9 @@ def _optim_state_dict(self, optimizer):
     @property
     def lightning_module(self) -> LightningModule:
         return unwrap_lightning_module_sharded(self._model)
+
+    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
+        pass
+
+    def post_training_step(self):
+        pass

pytorch_lightning/plugins/training_type/sharded_spawn.py (+10 lines)

@@ -13,6 +13,9 @@
 # limitations under the License.
 from typing import Optional
 
+import torch
+from torch.optim import Optimizer
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerState
@@ -32,6 +35,7 @@ def configure_ddp(self):
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
         )
+        setattr(self._model, "require_backward_grad_sync", False)
 
     def _reinit_optimizers_with_oss(self):
         optimizers = self.lightning_module.trainer.optimizers
@@ -65,3 +69,9 @@ def _optim_state_dict(self, optimizer):
     @property
     def lightning_module(self) -> LightningModule:
         return unwrap_lightning_module_sharded(self._model)
+
+    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
+        pass
+
+    def post_training_step(self):
+        pass
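
Note on the plugin changes above: setattr(self._model, "require_backward_grad_sync", False) gives fairscale's ShardedDataParallel the attribute that torch's DistributedDataParallel defines and that Lightning consults when deciding whether to sync gradients, so the lookup no longer fails. The pre_backward and post_training_step overrides are deliberate no-ops in both the ddp_sharded and ddp_sharded_spawn plugins: fairscale's wrapper handles gradient reduction itself, so the parent DDP plugin's behaviour in those hooks (which touches require_backward_grad_sync) is skipped. The description of the parent plugin's hooks here is inferred from the diff, not quoted from its source.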

tests/plugins/test_sharded_plugin.py (+29 lines)

@@ -278,3 +278,32 @@ def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs):
 
     trainer.validate(model)
     trainer.test(model)
+
+
+class ManualBoringModel(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+        self.automatic_optimization = False
+
+    def training_step(self, batch, batch_idx):
+        opt = self.optimizers()
+        opt.zero_grad()
+        output = self(batch)
+        loss = self.loss(batch, output)
+        self.manual_backward(loss)
+        opt.step()
+        return {"loss": loss}
+
+
+@RunIf(skip_windows=True, special=True, fairscale=True, min_gpus=2)
+@pytest.mark.parametrize("accelerator", ["ddp_sharded", "ddp_sharded_spawn"])
+def test_ddp_sharded_plugin_manual_optimization(tmpdir, accelerator):
+    model = ManualBoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        accelerator=accelerator,
+        fast_dev_run=2,
+        gpus=2,
+    )
+    trainer.fit(model)
