diff --git a/pl_examples/ipu_examples/__init__.py b/pl_examples/ipu_examples/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pl_examples/ipu_examples/mnist.py b/pl_examples/ipu_examples/mnist.py new file mode 100644 index 0000000000000..c907f4a15af48 --- /dev/null +++ b/pl_examples/ipu_examples/mnist.py @@ -0,0 +1,84 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.nn import functional as F + +import pytorch_lightning as pl +from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule + + +class LitClassifier(pl.LightningModule): + + def __init__( + self, + hidden_dim: int = 128, + learning_rate: float = 0.0001, + ): + super().__init__() + self.save_hyperparameters() + + self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim) + self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10) + + def forward(self, x): + x = x.view(x.size(0), -1) + x = torch.relu(self.l1(x)) + x = torch.relu(self.l2(x)) + return x + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + acc = self.accuracy(logits, y) + return acc + + def test_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + acc = self.accuracy(logits, y) + return acc + + def accuracy(self, logits, y): + # currently IPU poptorch doesn't implicit convert bools to tensor + # hence we use an explicit calculation for accuracy here. Once fixed in poptorch + # we can use the accuracy metric. 
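+        # (Illustrative note: once that support lands, this method could be replaced by
+        # a metrics package call or simply ``(logits.argmax(-1) == y).float().mean()``.)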
+ acc = torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y) + return acc + + def validation_epoch_end(self, outputs) -> None: + self.log('val_acc', torch.stack(outputs).mean(), prog_bar=True) + + def test_epoch_end(self, outputs) -> None: + self.log('test_acc', torch.stack(outputs).mean()) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + +if __name__ == '__main__': + dm = MNISTDataModule(batch_size=32) + + model = LitClassifier() + + trainer = pl.Trainer(max_epochs=2, ipu_cores=8) + + trainer.fit(model, datamodule=dm) + trainer.test(model, datamodule=dm) diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py index 05e15fe1f1767..2a460a27e373a 100644 --- a/pytorch_lightning/accelerators/__init__.py +++ b/pytorch_lightning/accelerators/__init__.py @@ -13,4 +13,5 @@ from pytorch_lightning.accelerators.accelerator import Accelerator # noqa F401 from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa F401 from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa F401 +from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa F401 from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa F401 diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py new file mode 100644 index 0000000000000..34bee31b5a91d --- /dev/null +++ b/pytorch_lightning/accelerators/ipu.py @@ -0,0 +1,34 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import Callable + +from torch.optim import Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class IPUAccelerator(Accelerator): + """ Accelerator for IPUs. """ + + def setup_optimizers(self, trainer: 'pl.Trainer') -> None: + super().setup_optimizers(trainer) + + if len(self.optimizers) > 1: + raise MisconfigurationException("IPUs currently only support one optimizer.") + + def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs): + # Optimizer step is handled by the IPU accelerator. 
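+        # Running the closure executes ``training_step`` through the poptorch
+        # ``trainingModel``, which performs the backward pass and the weight update
+        # on device, so no host-side ``optimizer.step()`` is needed here.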
+ lambda_closure() diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index 58d43dc54cb7f..cc95671ebf2cc 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -9,6 +9,7 @@ from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401 FullyShardedNativeMixedPrecisionPlugin, ) +from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401 @@ -20,6 +21,7 @@ from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.ipu import IPUPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin # noqa: F401 @@ -41,6 +43,8 @@ "DeepSpeedPrecisionPlugin", "DoublePrecisionPlugin", "HorovodPlugin", + "IPUPlugin", + "IPUPrecisionPlugin", "NativeMixedPrecisionPlugin", "PrecisionPlugin", "ShardedNativeMixedPrecisionPlugin", diff --git a/pytorch_lightning/plugins/precision/ipu_precision.py b/pytorch_lightning/plugins/precision/ipu_precision.py new file mode 100644 index 0000000000000..4e88a6cf73fe1 --- /dev/null +++ b/pytorch_lightning/plugins/precision/ipu_precision.py @@ -0,0 +1,24 @@ +from typing import Any + +from torch import Tensor + +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin + + +class IPUPrecisionPlugin(PrecisionPlugin): + + def __init__(self, precision: int) -> None: + super().__init__() + self.precision = precision + + def backward( + self, + closure_loss: Tensor, + *args: Any, + **kwargs: Any, + ) -> Tensor: + # IPU internally manages bwd step. 
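+        # poptorch compiles forward and backward into a single on-device program, so the
+        # loss is returned untouched rather than calling ``closure_loss.backward()``.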
+ return closure_loss + + def clip_gradients(self, *args, **kwargs) -> None: + pass diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py new file mode 100644 index 0000000000000..2527470a12166 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -0,0 +1,327 @@ +import inspect +import json +import os +from typing import Any, Iterable, List, Optional, Union + +import torch +from torch.utils.data import DataLoader + +from pytorch_lightning import _logger as log +from pytorch_lightning.callbacks import GradientAccumulationScheduler +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.trainer.supporters import CombinedLoader +from pytorch_lightning.utilities import _POPTORCH_AVAILABLE +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _POPTORCH_AVAILABLE: + import poptorch + + +class LightningIPUModule(_LightningModuleWrapperBase): + + def __init__(self, pl_module: LightningModule, precision: int): + super().__init__(pl_module) + self.precision = precision + + def forward(self, *inputs, **kwargs): + if self.precision == 16: + inputs = self._move_float_tensors_to_half(inputs) + + return super().forward(*inputs, **kwargs) + + @staticmethod + def batch_to(data): + return data.half() + + def _move_float_tensors_to_half(self, batch: Any): + batch = apply_to_collection(batch, (torch.FloatTensor, torch.cuda.FloatTensor), function=self.batch_to) + return batch + + +class IPUPlugin(ParallelPlugin): + """ + Plugin for training on IPU devices. + """ + + def __init__( + self, + device_iterations: int = 1, + autoround_num_ipus: bool = True, + autoreport: bool = True, + autoreport_dir: Optional[str] = None, + convert_model_to_half: bool = False, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + ) -> None: + """ + Arguments: + + device_iterations: Number of iterations to run on device at once before returning to host. + This can be used as an optimization to speed up training. + https://docs.graphcore.ai/projects/poptorch-user-guide/en/0.1.67/batching.html + autoround_num_ipus: When selecting multiple IPUs, auto-rounds to powers of 2 as required for IPUs. + autoreport: Enable auto-reporting for IPUs using PopVision + https://docs.graphcore.ai/projects/graphcore-popvision-user-guide/en/latest/graph/graph.html + autoreport_dir: Optional directory to store autoReport output. + convert_model_to_half: Converts the model to half precision, which can be used for pure FP16 training. 
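+
+        Example (illustrative only; the plugin is selected automatically when
+        ``ipu_cores`` is passed to the ``Trainer`` and can be passed explicitly to
+        customise these options)::
+
+            trainer = Trainer(ipu_cores=8, plugins=IPUPlugin(device_iterations=20))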
+ """ + super().__init__(parallel_devices, cluster_environment) + self.convert_model_to_half = convert_model_to_half + self.device_iterations = device_iterations + self.autoround_num_ipus = autoround_num_ipus + self.autoreport = autoreport + self.autoreport_dir = autoreport_dir + self.poptorch_models = {} + self._original_accumulate_grad_batches = None + + if self.autoreport: + options = {"autoReport.all": self.autoreport} + if self.autoreport_dir: + if not os.path.exists(self.autoreport_dir): + os.makedirs(self.autoreport_dir) + options["autoReport.directory"] = self.autoreport_dir + os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(options) + + def setup_environment(self) -> None: + super().setup_environment() + if not poptorch.ipuHardwareIsAvailable(): + raise MisconfigurationException("IPU Accelerator requires IPUs to run.") + + @property + def lightning_module(self) -> Optional[LightningModule]: + return self.model.module if isinstance(self.model, LightningIPUModule) else self.model + + def pre_dispatch(self) -> None: + self._handle_gradient_accumulation_steps() + if self.convert_model_to_half: + log.info('Using full 16bit precision, converting LightningModule weights to FP16.') + self.model = self.model.half() + precision = self.lightning_module.trainer.accelerator.precision_plugin.precision + precision = 16 if self.convert_model_to_half else precision + + model = LightningIPUModule(self.lightning_module, precision) + self.model = model + + # Separate models are instantiated for different stages, but they share the same weights on host. + # When validation/test models are run, weights are synced first. + + if self.lightning_module.trainer.training: + # Create model for training which will run training. + optimizer = self.lightning_module.trainer.optimizers[0] + model = poptorch.trainingModel(model=model, options=self._create_opts(training=True), optimizer=optimizer) + self.poptorch_models['train'] = model + for x in ('val', 'test', 'predict'): + model = poptorch.inferenceModel( + model=model, + options=self._create_opts(training=False), + ) + self.poptorch_models[x] = model + + @property + def replication_factor(self): + return len(self.parallel_devices) + + def _create_opts(self, training): + opts = poptorch.Options() + opts.deviceIterations(self.device_iterations) + opts.replicationFactor(self.replication_factor) + gradient_accumulation = self.lightning_module.trainer.accumulate_grad_batches if training else 1 + opts.Training.gradientAccumulation(gradient_accumulation) + opts.autoRoundNumIPUs(self.autoround_num_ipus) + + # todo (sean): unsure if this is necessary but to be safe. 
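+        # ``PL_GLOBAL_SEED`` is set by ``seed_everything``; forwarding it to
+        # ``opts.randomSeed`` keeps on-device RNG consistent with the host seed.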
+ if os.environ.get("PL_GLOBAL_SEED"): + opts.randomSeed(int(os.environ["PL_GLOBAL_SEED"])) + return opts + + def on_reset_train_dataloader(self, dataloader) -> Union[Iterable, DataLoader]: + return self.process_dataloader(dataloader) + + def on_reset_val_dataloader(self, dataloader) -> Union[Iterable, DataLoader]: + return self.process_dataloader(dataloader) + + def on_reset_test_dataloader(self, dataloader) -> Union[Iterable, DataLoader]: + return self.process_dataloader(dataloader) + + def on_reset_predict_dataloader(self, dataloader) -> Union[Iterable, DataLoader]: + return self.process_dataloader(dataloader) + + def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + if isinstance(dataloader, CombinedLoader): + dataloader.loaders = apply_to_collection( + dataloader.loaders, + DataLoader, + self.process_dataloader, + ) + return dataloader + elif isinstance(dataloader, list): + dataloader = apply_to_collection(dataloader, DataLoader, self.process_dataloader) + return dataloader + if not isinstance(dataloader, poptorch.DataLoader): + dataloader = self._convert_to_poptorch_loader( + dataloader=dataloader, opts=self._create_opts(training=self.lightning_module.training) + ) + return dataloader + + def _convert_to_poptorch_loader(self, dataloader: Union[Iterable, DataLoader], + opts: 'poptorch.Options') -> Union[Iterable, DataLoader]: + skip_keys = ('sampler', 'batch_sampler', 'dataset_kind') + + attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")} + + params = set(inspect.signature(dataloader.__init__).parameters) + contains_dataset = True + + if type(dataloader) is not DataLoader: + contains_dataset = "dataset" in params + params.update(inspect.signature(DataLoader.__init__).parameters) + + dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys} + + multiprocessing_context = dataloader.multiprocessing_context + dl_args['multiprocessing_context'] = multiprocessing_context + if not contains_dataset: + dl_args.pop('dataset') + + dataloader = poptorch.DataLoader(**dl_args, options=opts) + dataloader.multiprocessing_context = multiprocessing_context + return dataloader + + def _handle_gradient_accumulation_steps(self): + """ + This functions overrides the trainer.accumulation_scheduler to generate + ``accumulate_grad_batches=1``. + Therefore, ``optimizer_step`` will be called on every batch, and the IPU will handle grad accumulation. + """ + self._original_accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches + if self._original_accumulate_grad_batches > 1: + # todo (tchaton) Add support for accumulate_grad_batches being a dictionary. 
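+            # The real accumulation factor is passed to poptorch through
+            # ``opts.Training.gradientAccumulation`` in ``_create_opts``; the Lightning
+            # scheduler is forced to 1 so ``optimizer_step`` runs on every batch.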
+ self.lightning_module.trainer.accumulation_scheduler = GradientAccumulationScheduler({0: 1}) + + def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int: + if self._original_accumulate_grad_batches > 1: + if total_batch_idx % self._original_accumulate_grad_batches == 0: + current_global_step += 1 + return current_global_step + return super().update_global_step(total_batch_idx, current_global_step) + + @property + def _n_replicate(self): + # Ensure we replicate values to have enough dimensions to split across devices + accumulate_grad_batches = self._original_accumulate_grad_batches + return self.replication_factor * self.device_iterations * accumulate_grad_batches + + def _prepare_input(self, args): + + def to_tuple(x): + return tuple(x) + + def to_tensor(x): + return torch.tensor(x).unsqueeze(0).repeat(self._n_replicate) + + args = apply_to_collection(args, dtype=list, function=to_tuple) + args = apply_to_collection(args, dtype=(int, float), function=to_tensor) + return args + + def training_step(self, *args, **kwargs): + args = self._prepare_input(args) + return self.poptorch_models['train'](*args, **kwargs) + + def validation_step(self, *args, **kwargs): + args = self._prepare_input(args) + return self.poptorch_models['val'](*args, **kwargs) + + def test_step(self, *args, **kwargs): + args = self._prepare_input(args) + return self.poptorch_models['test'](*args, **kwargs) + + def predict_step(self, *args, **kwargs): + args = self._prepare_input(args) + return self.poptorch_models['predict'](*args, **kwargs) + + def teardown(self) -> None: + for k, model in self.poptorch_models.items(): + model.destroy() + + def _compiled(self, model): + # Required to ensure we only attach compiled models, as they are compiled lazily. + return model._executable is not None + + def _detach_models(self): + """ + Detaches all stage specific models from IPU devices. + """ + for k, model in self.poptorch_models.items(): + if self._compiled(model) and model.isAttachedToDevice(): + model.detachFromDevice() + + def _load_model(self, stage): + """ + Loads the stage specific accelerator model onto device if compiled and not attached to IPU devices. 
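+        All other stage models are detached first (via ``_detach_models``) so the
+        requested IPUs are free for this stage's executable.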
+ Args: + stage: The stage to load + """ + self._detach_models() + model = self.poptorch_models[stage] + if self._compiled(model) and not model.isAttachedToDevice(): + model.attachToDevice() + + def on_train_start(self): + self._load_model('train') + + def on_validation_start(self): + self._load_model('val') + + def on_test_start(self): + self._load_model('test') + + def on_predict_start(self): + self._load_model('predict') + + def on_train_end(self): + self._detach_models() + + def on_validation_end(self): + self._detach_models() + + def on_test_end(self): + self._detach_models() + + def on_predict_end(self): + self._detach_models() + + def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + # Updates optimizer stats if LR scheduler modified the optimizer state + optimizer = self.lightning_module.trainer.optimizers[0] + self.poptorch_models['train'].setOptimizer(optimizer) + + @property + def on_gpu(self) -> bool: + return False + + @property + def root_device(self) -> torch.device: + pass + + def model_to_device(self) -> None: + pass + + @property + def is_global_zero(self) -> bool: + return True + + def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: + return tensor + + def barrier(self, name: Optional[str] = None) -> None: + pass + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + return tensor + + def broadcast(self, obj: object, src: int = 0) -> object: + return obj diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 4d692ec517d19..1d50a93b0b086 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -21,6 +21,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.gpu import GPUAccelerator +from pytorch_lightning.accelerators.ipu import IPUAccelerator from pytorch_lightning.accelerators.tpu import TPUAccelerator from pytorch_lightning.plugins import ( ApexMixedPrecisionPlugin, @@ -36,6 +37,8 @@ DoublePrecisionPlugin, FullyShardedNativeMixedPrecisionPlugin, HorovodPlugin, + IPUPlugin, + IPUPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin, ShardedNativeMixedPrecisionPlugin, @@ -57,6 +60,7 @@ from pytorch_lightning.utilities import ( _APEX_AVAILABLE, _HOROVOD_AVAILABLE, + _IPU_AVAILABLE, _NATIVE_AMP_AVAILABLE, _TPU_AVAILABLE, AMPType, @@ -79,6 +83,7 @@ def __init__( self, num_processes, tpu_cores, + ipu_cores, distributed_backend, auto_select_gpus, gpus, @@ -98,6 +103,7 @@ def __init__( self.num_processes = num_processes self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores) + self.ipu_cores = ipu_cores self.distributed_backend = distributed_backend self.auto_select_gpus = auto_select_gpus self.gpus = gpus @@ -248,6 +254,10 @@ def on_cpu(self) -> bool: def on_tpu(self) -> bool: return self.tpu_cores is not None + @property + def on_ipu(self) -> bool: + return self.ipu_cores is not None + @property def tpu_id(self) -> Optional[int]: if self.on_tpu and isinstance(self.tpu_cores, list): @@ -323,13 +333,18 @@ def parallel_devices(self) -> List[Union[torch.device, int]]: # https://github.com/PyTorchLightning/pytorch-lightning/issues/3169 if isinstance(self.tpu_cores, int): devices = list(range(self.tpu_cores)) + elif 
self.on_ipu: + if isinstance(self.ipu_cores, int): + devices = list(range(self.ipu_cores)) else: devices = [torch.device("cpu")] * self.num_processes return devices @property def root_gpu(self) -> Optional[int]: - return self.accelerator.root_device.index if not isinstance(self.accelerator, TPUAccelerator) else None + return self.accelerator.root_device.index if not isinstance( + self.accelerator, (IPUAccelerator, TPUAccelerator) + ) else None @property def is_training_type_in_plugins(self) -> bool: @@ -353,6 +368,9 @@ def select_precision_plugin(self) -> PrecisionPlugin: # set precision type self.amp_type = AMPType.from_str(self.amp_type) + if self.on_ipu: + return IPUPrecisionPlugin(self.precision) + if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin): return DeepSpeedPrecisionPlugin(self.precision) @@ -459,6 +477,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: plugin = HorovodPlugin(parallel_devices=self.parallel_devices) elif self.on_tpu and isinstance(self.tpu_cores, list): plugin = SingleTPUPlugin(self.tpu_id) + elif self.on_ipu: + plugin = IPUPlugin(parallel_devices=self.parallel_devices) else: single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids) plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu")) @@ -499,6 +519,8 @@ def select_accelerator(self) -> Accelerator: acc_cls = GPUAccelerator elif self.on_tpu: acc_cls = TPUAccelerator + elif self.on_ipu: + acc_cls = IPUAccelerator else: acc_cls = CPUAccelerator # as precision_plugin is dependent on training_type_plugin, make sure @@ -562,6 +584,8 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): self._device_type = DeviceType.TPU if isinstance(self.tpu_cores, int): self._distrib_type = DistributedType.TPU_SPAWN + elif self.distributed_backend == 'ipu': + self._device_type = DeviceType.IPU elif self.distributed_backend and self._distrib_type is None: self._distrib_type = DistributedType(self.distributed_backend) @@ -609,8 +633,11 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): ) rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}') - num_cores = self.tpu_cores if self.tpu_cores is not None else 0 - rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores') + num_tpu_cores = self.tpu_cores if self.tpu_cores is not None else 0 + rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores') + + num_ipu_cores = self.ipu_cores if self.ipu_cores is not None else 0 + rank_zero_info(f'IPU available: {_IPU_AVAILABLE}, using: {num_ipu_cores} IPU cores') if torch.cuda.is_available() and self._device_type != DeviceType.GPU: rank_zero_warn( diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 11a8b02bd3b95..0bdf0f73d0e0a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -107,6 +107,7 @@ def __init__( gpus: Optional[Union[List[int], str, int]] = None, auto_select_gpus: bool = False, tpu_cores: Optional[Union[List[int], str, int]] = None, + ipu_cores: Optional[int] = None, log_gpu_memory: Optional[str] = None, progress_bar_refresh_rate: Optional[int] = None, overfit_batches: Union[int, float] = 0.0, @@ -323,8 +324,8 @@ def __init__( self.optimizer_connector = OptimizerConnector(self) self.accelerator_connector = AcceleratorConnector( - num_processes, 
tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, benchmark, - replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins + num_processes, tpu_cores, ipu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, + benchmark, replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins ) self.logger_connector = LoggerConnector(self, log_gpu_memory) self.model_connector = ModelConnector(self) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 6664be43bef88..613a5013d5198 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -43,10 +43,12 @@ _HOROVOD_AVAILABLE, _HYDRA_AVAILABLE, _HYDRA_EXPERIMENTAL_AVAILABLE, + _IPU_AVAILABLE, _IS_INTERACTIVE, _module_available, _NATIVE_AMP_AVAILABLE, _OMEGACONF_AVAILABLE, + _POPTORCH_AVAILABLE, _RPC_AVAILABLE, _TORCH_GREATER_EQUAL_1_5, _TORCH_GREATER_EQUAL_1_6, diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 98e10a9126a44..3cb4d24d126fa 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -97,6 +97,7 @@ class DeviceType(LightningEnum): """ CPU = 'CPU' GPU = 'GPU' + IPU = 'IPU' TPU = 'TPU' diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index f40d092f68e9f..2a51b01404821 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -85,6 +85,7 @@ def _compare_version(package: str, op, version) -> bool: _KINETO_AVAILABLE = _TORCH_GREATER_EQUAL_1_8_1 and torch.profiler.kineto_available() _NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") _OMEGACONF_AVAILABLE = _module_available("omegaconf") +_POPTORCH_AVAILABLE = _module_available('poptorch') _RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc') _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none']) _TORCHTEXT_AVAILABLE = _module_available("torchtext") @@ -96,3 +97,9 @@ def _compare_version(package: str, op, version) -> bool: from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402 _TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() + +if _POPTORCH_AVAILABLE: + import poptorch + _IPU_AVAILABLE = poptorch.ipuHardwareIsAvailable() +else: + _IPU_AVAILABLE = False diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py new file mode 100644 index 0000000000000..0d5a6e89bb331 --- /dev/null +++ b/tests/accelerators/test_ipu.py @@ -0,0 +1,263 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from typing import Optional + +import pytest +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from pytorch_lightning import Callback, seed_everything, Trainer +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins import IPUPlugin, IPUPrecisionPlugin +from tests.helpers.boring_model import BoringModel +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.datasets import SklearnDataset +from tests.helpers.runif import RunIf +from tests.helpers.simple_models import ClassificationModel + + +class IPUModel(BoringModel): + + def training_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return loss + + def validation_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return loss + + def test_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return loss + + def training_epoch_end(self, outputs) -> None: + pass + + def validation_epoch_end(self, outputs) -> None: + pass + + def test_epoch_end(self, outputs) -> None: + pass + + +class IPUClassificationModel(ClassificationModel): + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + acc = self.accuracy(logits, y) + return acc + + def test_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + acc = self.accuracy(logits, y) + return acc + + def accuracy(self, logits, y): + # todo (sean): currently IPU poptorch doesn't implicit convert bools to tensor + # hence we use an explicit calculation for accuracy here. Once fixed in poptorch + # we can use the accuracy metric. + acc = torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y) + return acc + + def validation_epoch_end(self, outputs) -> None: + self.log('val_acc', torch.stack(outputs).mean()) + + def test_epoch_end(self, outputs) -> None: + self.log('test_acc', torch.stack(outputs).mean()) + + +@RunIf(ipu=True, special=True) +@pytest.mark.parametrize('ipu_cores', [1, 4]) +def test_all_stages(tmpdir, ipu_cores): + model = IPUModel() + trainer = Trainer(fast_dev_run=True, ipu_cores=ipu_cores) + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + trainer.predict(model, model.val_dataloader()) + + +@RunIf(ipu=True, special=True) +@pytest.mark.parametrize('ipu_cores', [1, 4]) +def test_inference_only(tmpdir, ipu_cores): + model = IPUModel() + + trainer = Trainer(fast_dev_run=True, ipu_cores=ipu_cores) + trainer.validate(model) + trainer.test(model) + trainer.predict(model, model.val_dataloader()) + + +@RunIf(ipu=True, special=True) +def test_optimization(tmpdir): + seed_everything(42) + + # Override to drop last uneven batch, as IPU poptorch does not support uneven inputs. 
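+    # (poptorch folds device iterations, replicas and gradient accumulation into one
+    # combined batch per step, so a short final batch cannot be split evenly;
+    # ``drop_last=True`` below avoids that.)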
+ class DataModule(ClassifDataModule): + + def train_dataloader(self): + return DataLoader( + SklearnDataset(self.x_train, self.y_train, self._x_type, self._y_type), + batch_size=self.batch_size, + drop_last=True + ) + + def val_dataloader(self): + return DataLoader( + SklearnDataset(self.x_valid, self.y_valid, self._x_type, self._y_type), + batch_size=self.batch_size, + drop_last=True + ) + + def test_dataloader(self): + return DataLoader( + SklearnDataset(self.x_test, self.y_test, self._x_type, self._y_type), + batch_size=self.batch_size, + drop_last=True + ) + + dm = DataModule(length=1024) + model = IPUClassificationModel() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + weights_summary=None, + deterministic=True, + ipu_cores=2, + ) + + # fit model + trainer.fit(model, dm) + assert trainer.state.finished, f"Training failed with {trainer.state}" + assert dm.trainer is not None + + # validate + result = trainer.validate(datamodule=dm) + assert dm.trainer is not None + assert result[0]['val_acc'] > 0.7 + + # test + result = trainer.test(datamodule=dm) + assert dm.trainer is not None + test_result = result[0]['test_acc'] + assert test_result > 0.6 + + # test saved model + model_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(model_path) + + model = IPUClassificationModel.load_from_checkpoint(model_path) + + trainer = Trainer(default_root_dir=tmpdir, deterministic=True) + + result = trainer.test(model, dm.test_dataloader()) + saved_result = result[0]['test_acc'] + assert saved_result > 0.6 and (saved_result == test_result) + + +@RunIf(ipu=True, special=True) +def test_mixed_precision(tmpdir): + + class TestCallback(Callback): + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: + assert isinstance(trainer.accelerator.precision_plugin, IPUPrecisionPlugin) + assert trainer.accelerator.precision_plugin.precision == 16 + assert trainer.accelerator.model.precision == 16 + raise SystemExit + + model = IPUModel() + trainer = Trainer(fast_dev_run=True, ipu_cores=1, precision=16, callbacks=TestCallback()) + with pytest.raises(SystemExit): + trainer.fit(model) + + +@RunIf(ipu=True, special=True) +def test_pure_half_precision(tmpdir): + + class TestCallback(Callback): + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin) + assert isinstance(trainer.accelerator.precision_plugin, IPUPrecisionPlugin) + assert trainer.accelerator.precision_plugin.precision == 16 + assert trainer.accelerator.model.precision == 16 + assert trainer.accelerator.training_type_plugin.convert_model_to_half + for param in trainer.accelerator.model.parameters(): + assert param.dtype == torch.float16 + raise SystemExit + + model = IPUModel() + trainer = Trainer( + fast_dev_run=True, + ipu_cores=1, + precision=16, + plugins=IPUPlugin(convert_model_to_half=True), + callbacks=TestCallback() + ) + with pytest.raises(SystemExit): + trainer.fit(model) + + +@RunIf(ipu=True, special=True) +def test_device_iterations_ipu_plugin(tmpdir): + + class TestCallback(Callback): + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin) + assert trainer.accelerator.training_type_plugin.device_iterations == 20 + # assert device iterations has been set correctly within the poptorch options + poptorch_model = 
trainer.accelerator.training_type_plugin.poptorch_models['train'] + assert poptorch_model._options.toDict()['device_iterations'] == 20 + raise SystemExit + + model = IPUModel() + trainer = Trainer(fast_dev_run=True, ipu_cores=1, plugins=IPUPlugin(device_iterations=20), callbacks=TestCallback()) + with pytest.raises(SystemExit): + trainer.fit(model) + + +@RunIf(ipu=True, special=True) +def test_accumulated_batches(tmpdir): + + class TestCallback(Callback): + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + # ensure the accumulation_scheduler is overridden to accumulate every batch + # since ipu handle accumulation + assert trainer.accumulation_scheduler.scheduling == {0: 1} + # assert poptorch option have been set correctly + poptorch_model = trainer.accelerator.training_type_plugin.poptorch_models['train'] + assert poptorch_model._options.Training.toDict()['gradient_accumulation'] == 2 + raise SystemExit + + model = IPUModel() + trainer = Trainer(fast_dev_run=True, ipu_cores=1, accumulate_grad_batches=2, callbacks=TestCallback()) + with pytest.raises(SystemExit): + trainer.fit(model) diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 630a341ec2d30..737ddd68dff17 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -27,6 +27,7 @@ _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE, _HOROVOD_AVAILABLE, + _IPU_AVAILABLE, _NATIVE_AMP_AVAILABLE, _RPC_AVAILABLE, _TORCH_QUANTIZE_AVAILABLE, @@ -63,6 +64,7 @@ def __new__( amp_apex: bool = False, amp_native: bool = False, tpu: bool = False, + ipu: bool = False, horovod: bool = False, horovod_nccl: bool = False, skip_windows: bool = False, @@ -85,6 +87,7 @@ def __new__( amp_apex: NVIDIA Apex is installed amp_native: if native PyTorch native AMP is supported tpu: if TPU is available + ipu: if IPU is available horovod: if Horovod is installed horovod_nccl: if Horovod is installed with NCCL support skip_windows: skip test for Windows platform (typically fo some limited torch functionality) @@ -139,6 +142,10 @@ def __new__( conditions.append(not _TPU_AVAILABLE) reasons.append("TPU") + if ipu: + conditions.append(not _IPU_AVAILABLE) + reasons.append("IPU") + if horovod: conditions.append(not _HOROVOD_AVAILABLE) reasons.append("Horovod")