Remove deprecated distributed_backend from Trainer #10017

Merged · 5 commits · Oct 19, 2021
Changes from 3 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -496,6 +496,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `should_rank_save_checkpoint` property from Trainer ([#9433](https://github.com/PyTorchLightning/pytorch-lightning/pull/9433))


- Removed deprecated `distributed_backend` from `Trainer` ([#10017](https://github.com/PyTorchLightning/pytorch-lightning/pull/10017))


### Fixed


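For readers migrating existing scripts, a minimal sketch of the change recorded above (flag values are illustrative assumptions, not taken from this diff, and the listed hardware is assumed to be available):

    from pytorch_lightning import Trainer

    # Before (deprecated since v1.5, removed by this PR):
    # trainer = Trainer(distributed_backend="ddp")

    # After: the device type goes in `accelerator`, the distributed strategy in `strategy`.
    trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp")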
6 changes: 1 addition & 5 deletions docs/source/common/trainer.rst
@@ -216,7 +216,7 @@ accelerator

|

The accelerator backend to use (previously known as distributed_backend).
The accelerator backend to use:

- (``'dp'``) is DataParallel (split batch among GPUs of same machine)
- (``'ddp'``) is DistributedDataParallel (each gpu on each node trains, and syncs grads)
@@ -553,10 +553,6 @@ will need to be set up to use remote filepaths.
# default used by the Trainer
trainer = Trainer(default_root_dir=os.getcwd())

distributed_backend
^^^^^^^^^^^^^^^^^^^
Deprecated: This has been renamed ``accelerator``.

enable_checkpointing
^^^^^^^^^^^^^^^^^^^^

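As a usage illustration of the ``accelerator`` flag documented above (a sketch with assumed values; since v1.5 the same backend names are preferred via ``strategy``):

    from pytorch_lightning import Trainer

    # DataParallel: split each batch among the GPUs of one machine
    trainer = Trainer(accelerator="dp", gpus=2)

    # DistributedDataParallel, spelled with the newer `strategy` flag
    trainer = Trainer(strategy="ddp", gpus=8, num_nodes=4)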
101 changes: 39 additions & 62 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -92,7 +92,6 @@ def __init__(
devices,
tpu_cores,
ipus,
distributed_backend,
accelerator,
strategy: Optional[Union[str, TrainingTypePlugin]],
gpus,
@@ -113,7 +112,7 @@ def __init__(
self._accelerator_type = None

self.strategy = strategy.lower() if isinstance(strategy, str) else strategy
self.distributed_backend = distributed_backend or accelerator
self.accelerator = accelerator

self._init_deterministic(deterministic)

@@ -152,7 +151,7 @@ def __init__(

self.plugins = plugins

self._handle_accelerator_and_distributed_backend(distributed_backend, accelerator)
self._handle_accelerator_and_strategy()

self._validate_accelerator_and_devices()

@@ -176,10 +175,6 @@ def __init__(
self._training_type_plugin_resolved = False
self.accelerator = self.select_accelerator()

# override dist backend when using tpus
if self.use_tpu:
self.distributed_backend = "tpu"

# init flags for SLURM+DDP to work
self.world_size = 1
self.interactive_ddp_procs = []
@@ -207,7 +202,7 @@ def _init_deterministic(self, deterministic: bool) -> None:
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

def select_accelerator_type(self) -> None:
if self.distributed_backend == "auto":
if self.accelerator == "auto":
if self.has_tpu:
self._accelerator_type = DeviceType.TPU
elif self.has_ipu:
@@ -217,34 +212,34 @@ def select_accelerator_type(self) -> None:
else:
self._set_devices_to_cpu_num_processes()
self._accelerator_type = DeviceType.CPU
elif self.distributed_backend == DeviceType.TPU:
elif self.accelerator == DeviceType.TPU:
if not self.has_tpu:
msg = "TPUs are not available" if not _TPU_AVAILABLE else "you didn't pass `tpu_cores` to `Trainer`"
raise MisconfigurationException(f"You passed `accelerator='tpu'`, but {msg}.")
self._accelerator_type = DeviceType.TPU
elif self.distributed_backend == DeviceType.IPU:
elif self.accelerator == DeviceType.IPU:
if not self.has_ipu:
msg = "IPUs are not available" if not _IPU_AVAILABLE else "you didn't pass `ipus` to `Trainer`"
raise MisconfigurationException(f"You passed `accelerator='ipu'`, but {msg}.")
self._accelerator_type = DeviceType.IPU
elif self.distributed_backend == DeviceType.GPU:
elif self.accelerator == DeviceType.GPU:
if not self.has_gpu:
msg = "you didn't pass `gpus` to `Trainer`" if torch.cuda.is_available() else "GPUs are not available"
raise MisconfigurationException(f"You passed `accelerator='gpu'`, but {msg}.")
self._accelerator_type = DeviceType.GPU
elif self.distributed_backend == DeviceType.CPU:
elif self.accelerator == DeviceType.CPU:
self._set_devices_to_cpu_num_processes()
self._accelerator_type = DeviceType.CPU

if self.distributed_backend in self.accelerator_types:
self.distributed_backend = None
if self.accelerator in self.accelerator_types:
self.accelerator = None

def _validate_accelerator_and_devices(self) -> None:
if self.distributed_backend not in self.accelerator_types and self.devices is not None:
if self.accelerator not in self.accelerator_types and self.devices is not None:
raise MisconfigurationException(
f"You passed `devices={self.devices}` but haven't specified"
" `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping,"
f" got `accelerator={self.distributed_backend!r}`."
f" got `accelerator={self.accelerator!r}`."
)

def _validate_accelerator_type(self) -> None:
@@ -260,16 +255,16 @@ def _warn_if_devices_flag_ignored(self) -> None:
if self.devices is None:
return
devices_warning = f"The flag `devices={self.devices}` will be ignored, as you have set"
if self.distributed_backend in ("auto", DeviceType.TPU):
if self.accelerator in ("auto", DeviceType.TPU):
if self.tpu_cores is not None:
rank_zero_warn(f"{devices_warning} `tpu_cores={self.tpu_cores}`")
elif self.distributed_backend in ("auto", DeviceType.IPU):
elif self.accelerator in ("auto", DeviceType.IPU):
if self.ipus is not None:
rank_zero_warn(f"{devices_warning} `ipus={self.ipus}`")
elif self.distributed_backend in ("auto", DeviceType.GPU):
elif self.accelerator in ("auto", DeviceType.GPU):
if self.gpus is not None:
rank_zero_warn(f"{devices_warning} `gpus={self.gpus}`")
elif self.distributed_backend in ("auto", DeviceType.CPU):
elif self.accelerator in ("auto", DeviceType.CPU):
if self.num_processes != 1:
rank_zero_warn(f"{devices_warning} `num_processes={self.num_processes}`")

@@ -285,31 +280,16 @@ def _set_devices_if_none(self) -> None:
elif self._accelerator_type == DeviceType.CPU:
self.devices = self.num_processes

def _handle_accelerator_and_distributed_backend(
self, distributed_backend: Optional[str], accelerator: Optional[Union[str, Accelerator]]
) -> None:
if distributed_backend is not None:
def _handle_accelerator_and_strategy(self) -> None:
if self.accelerator is not None and self.accelerator in list(DistributedType):
rank_zero_deprecation(
f"`Trainer(distributed_backend={distributed_backend!r})` "
"has been deprecated and will be removed in v1.5."
f" Use `Trainer(strategy={distributed_backend!r})` instead."
f"Passing `Trainer(accelerator={self.accelerator!r})` has been deprecated"
f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={self.accelerator!r})` instead."
)
if self.strategy is not None:
raise MisconfigurationException(
f"You have passed `Trainer(strategy={self.strategy!r})` but have"
f" also passed `Trainer(distributed_backend={distributed_backend!r})`."
f" HINT: Use just `Trainer(strategy={self.strategy!r})` instead."
)

if accelerator is not None and accelerator in list(DistributedType):
rank_zero_deprecation(
f"Passing `Trainer(accelerator={accelerator!r})` has been deprecated"
f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator!r})` instead."
)
if self.strategy is not None:
raise MisconfigurationException(
f"You have passed `Trainer(strategy={self.strategy!r})` but have"
f" also passed `Trainer(accelerator={accelerator!r})`."
f" also passed `Trainer(accelerator={self.accelerator!r})`."
f" HINT: Use just `Trainer(strategy={self.strategy!r})` instead."
)
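A short sketch of the behavior implemented by `_handle_accelerator_and_strategy` above (assumed flag values; a v1.5-era install is required):

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    Trainer(strategy="ddp_spawn")        # preferred spelling
    Trainer(accelerator="ddp")           # still accepted, but emits a deprecation warning

    try:
        Trainer(accelerator="ddp", strategy="ddp_spawn")
    except MisconfigurationException:
        pass                             # passing both is rejected with a hint to use `strategy`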

@@ -655,11 +635,8 @@ def select_precision_plugin(self) -> PrecisionPlugin:
return ApexMixedPrecisionPlugin(self.amp_level)

def select_training_type_plugin(self) -> TrainingTypePlugin:
if (
isinstance(self.distributed_backend, Accelerator)
and self.distributed_backend.training_type_plugin is not None
):
plugin = self.distributed_backend.training_type_plugin
if isinstance(self.accelerator, Accelerator) and self.accelerator.training_type_plugin is not None:
plugin = self.accelerator.training_type_plugin
elif self.use_ddp2:
plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment)
elif self.use_ddp and self.use_deepspeed:
@@ -741,15 +718,15 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
return training_type

def select_accelerator(self) -> Accelerator:
if isinstance(self.distributed_backend, Accelerator):
if isinstance(self.accelerator, Accelerator):
# custom accelerator from user
if self._precision_plugin is not None or self._training_type_plugin is not None:
# plugins also specified by user
rank_zero_warn(
"Specified `Precision` and `TrainingType` plugins will be ignored,"
" since an `Accelerator` instance was provided."
)
return self.distributed_backend
return self.accelerator

if self.use_gpu:
acc_cls = GPUAccelerator
@@ -783,38 +760,38 @@ def select_cluster_environment(self) -> ClusterEnvironment:
env = LightningEnvironment()
return env

def set_distributed_mode(self, distributed_backend: Optional[str] = None):
def set_distributed_mode(self, strategy: Optional[str] = None):

if distributed_backend is None and self.is_training_type_in_plugins:
if strategy is None and self.is_training_type_in_plugins:
return

if distributed_backend is not None and distributed_backend in TrainingTypePluginsRegistry:
self.distributed_backend = TrainingTypePluginsRegistry[distributed_backend]["distributed_backend"]
elif distributed_backend is not None:
self.distributed_backend = distributed_backend
if strategy is not None and strategy in TrainingTypePluginsRegistry:
self.accelerator = TrainingTypePluginsRegistry[strategy]["distributed_backend"]
elif strategy is not None:
self.accelerator = strategy

if isinstance(self.distributed_backend, Accelerator):
if isinstance(self.accelerator, Accelerator):
return

is_cpu_accelerator_type = self._accelerator_type and self._accelerator_type == DeviceType.CPU
_use_cpu = is_cpu_accelerator_type or self.distributed_backend and "cpu" in self.distributed_backend
_use_cpu = is_cpu_accelerator_type or self.accelerator and "cpu" in self.accelerator

if self.distributed_backend is None:
if self.accelerator is None:
if self.has_horovodrun():
self._set_horovod_backend()
elif self.num_gpus == 0 and self.num_nodes > 1:
self._distrib_type = DistributedType.DDP
elif self.num_gpus == 0 and self.num_processes > 1:
self.distributed_backend = DistributedType.DDP_SPAWN
self.accelerator = DistributedType.DDP_SPAWN
elif self.num_gpus > 1 and not _use_cpu:
rank_zero_warn(
"You requested multiple GPUs but did not specify a backend, e.g."
' `Trainer(strategy="dp"|"ddp"|"ddp2")`. Setting `strategy="ddp_spawn"` for you.'
)
self.distributed_backend = DistributedType.DDP_SPAWN
self.accelerator = DistributedType.DDP_SPAWN

# special case with DDP on CPUs
if self.distributed_backend == DistributedType.DDP_CPU:
if self.accelerator == DistributedType.DDP_CPU:
if _TPU_AVAILABLE:
raise MisconfigurationException(
"`accelerator='ddp_cpu'` is not supported on TPU machines. "
@@ -839,8 +816,8 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
self._distrib_type = DistributedType.TPU_SPAWN
elif self.has_ipu and not _use_cpu:
self._device_type = DeviceType.IPU
elif self.distributed_backend and self._distrib_type is None:
self._distrib_type = DistributedType(self.distributed_backend)
elif self.accelerator and self._distrib_type is None:
self._distrib_type = DistributedType(self.accelerator)

if self.num_gpus > 0 and not _use_cpu:
self._device_type = DeviceType.GPU
@@ -873,7 +850,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
self.num_processes = self.num_nodes

# Horovod is an extra case...
if self.distributed_backend == DistributedType.HOROVOD:
if self.accelerator == DistributedType.HOROVOD:
self._set_horovod_backend()

using_valid_distributed = self.use_ddp or self.use_ddp2
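The fallback logic kept in `set_distributed_mode` can be summarized with a sketch (assuming a machine where the listed devices are actually available):

    from pytorch_lightning import Trainer

    # Multiple GPUs with no strategy given: warns and falls back to "ddp_spawn".
    trainer = Trainer(gpus=2)

    # Multiple CPU processes with no strategy given: also defaults to DDP spawn.
    trainer = Trainer(num_processes=4)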
11 changes: 1 addition & 10 deletions pytorch_lightning/trainer/trainer.py
@@ -176,7 +176,6 @@ def __init__(
plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
amp_backend: str = "native",
amp_level: Optional[str] = None,
distributed_backend: Optional[str] = None,
move_metrics_to_cpu: bool = False,
multiple_trainloader_mode: str = "max_size_cycle",
stochastic_weight_avg: bool = False,
@@ -187,7 +186,7 @@

Args:

accelerator: Previously known as distributed_backend (dp, ddp, ddp2, etc...).
accelerator: The accelerator backend to use (``dp``, ``ddp``, ``ddp2``, etc.).
Can also take in an accelerator object for custom hardware.

accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
@@ -241,8 +240,6 @@ def __init__(
devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
based on the accelerator type.

distributed_backend: Deprecated. Please use ``accelerator``.

fast_dev_run: Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
of train, val and test to find any bugs (ie: a sort of unit test).

@@ -430,7 +427,6 @@ def __init__(
devices,
tpu_cores,
ipus,
distributed_backend,
accelerator,
strategy,
gpus,
@@ -1513,11 +1509,6 @@ def _on_exception(self):
def accelerator(self) -> Accelerator:
return self.accelerator_connector.accelerator

@property
def distributed_backend(self) -> Optional[str]:
# for backward compatibility
return self.accelerator_connector.distributed_backend

@property
def training_type_plugin(self) -> TrainingTypePlugin:
return self.accelerator.training_type_plugin
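The user-visible effect of removing the property above, as a sketch (CPU-only setup assumed):

    from pytorch_lightning import Trainer

    trainer = Trainer()
    trainer.accelerator             # still available: returns the Accelerator instance
    # trainer.distributed_backend  # removed by this PR; accessing it now raises AttributeError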
5 changes: 0 additions & 5 deletions tests/accelerators/test_accelerator_connector.py
@@ -634,11 +634,6 @@ def test_accelerator_ddp_for_cpu(tmpdir):
assert isinstance(trainer.training_type_plugin, DDPPlugin)


def test_exception_when_strategy_used_with_distributed_backend():
with pytest.raises(MisconfigurationException, match="but have also passed"):
Trainer(distributed_backend="ddp_cpu", strategy="ddp_spawn")


def test_exception_when_strategy_used_with_accelerator():
with pytest.raises(MisconfigurationException, match="but have also passed"):
Trainer(accelerator="ddp", strategy="ddp_spawn")
22 changes: 0 additions & 22 deletions tests/deprecated_api/test_remove_1-5.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/models/test_gpu.py
@@ -39,7 +39,7 @@

@RunIf(min_gpus=2)
def test_multi_gpu_none_backend(tmpdir):
"""Make sure when using multiple GPUs the user can't use `distributed_backend = None`."""
"""Make sure when using multiple GPUs the user can't use `accelerator = None`."""
tutils.set_random_master_port()
trainer_options = dict(
default_root_dir=tmpdir,
11 changes: 2 additions & 9 deletions tests/models/test_tpu.py
@@ -39,13 +39,6 @@

SERIAL_EXEC = xmp.MpSerialExecutor()

_LARGER_DATASET = RandomDataset(32, 2000)


# 8 cores needs a big dataset
def _serial_train_loader():
return DataLoader(_LARGER_DATASET, batch_size=32)


class SerialLoaderBoringModel(BoringModel):
def train_dataloader(self):
@@ -277,9 +270,9 @@ def test_exception_when_no_tpu_found(tmpdir):

@pytest.mark.parametrize("tpu_cores", [1, 8, [1]])
@RunIf(tpu=True)
def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):
def test_accelerator_set_when_using_tpu(tmpdir, tpu_cores):
"""Test if distributed_backend is set to `tpu` when tpu_cores is not None."""
assert Trainer(tpu_cores=tpu_cores).distributed_backend == "tpu"
assert isinstance(Trainer(tpu_cores=tpu_cores).accelerator, TPUAccelerator)


@RunIf(tpu=True)