IPU Integration #7735

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed · wants to merge 34 commits

88 changes: 88 additions & 0 deletions pl_examples/ipu_examples/mnist.py
@@ -0,0 +1,88 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pprint import pprint

import torch
from torch.nn import functional as F

import pytorch_lightning as pl
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule


class LitClassifier(pl.LightningModule):

    def __init__(
        self,
        hidden_dim: int = 128,
        learning_rate: float = 0.0001,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        acc = self.accuracy(logits, y)
        return acc

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        acc = self.accuracy(logits, y)
        return acc

    def accuracy(self, logits, y):
        # todo (sean): currently IPU poptorch doesn't implicitly convert bools to tensors,
        # hence we use an explicit accuracy calculation here. Once this is fixed in poptorch
        # we can use the accuracy metric.
        acc = torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y)
        return acc

    def validation_epoch_end(self, outputs) -> None:
        self.log('val_acc', torch.stack(outputs).mean(), prog_bar=True)
Contributor Author:
Need to clarify why this aggregation happens here and not in the step itself: the step functions are JIT-compiled and their outputs are collated from all devices, so reductions such as mean averaging cannot happen inside the step functions.


    def test_epoch_end(self, outputs) -> None:
        self.log('test_acc', torch.stack(outputs).mean())

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


if __name__ == '__main__':
    dm = MNISTDataModule(batch_size=32)

    model = LitClassifier()

    trainer = pl.Trainer(max_epochs=2, ipu_cores=8)

    trainer.fit(model, datamodule=dm)

    result = trainer.test(model, datamodule=dm)
    pprint(result)
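
As a side note on the accuracy workaround in the example above: a minimal sketch (not part of this PR) of what the calculation could look like once poptorch converts booleans implicitly; the torchmetrics.Accuracy constructor arguments are an assumption and vary by torchmetrics version.

import torch
import torchmetrics

# Hypothetical replacement for the manual torch.eq(...) accuracy above.
metric = torchmetrics.Accuracy(task="multiclass", num_classes=10)

logits = torch.randn(32, 10)           # stand-in for model outputs
targets = torch.randint(0, 10, (32,))  # stand-in for labels

preds = torch.argmax(logits, dim=-1)   # integer class predictions
print(metric(preds, targets))          # same value as the manual calculation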
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/__init__.py
@@ -13,4 +13,5 @@
from pytorch_lightning.accelerators.accelerator import Accelerator # noqa F401
from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa F401
from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa F401
from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa F401
from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa F401
32 changes: 32 additions & 0 deletions pytorch_lightning/accelerators/ipu.py
@@ -0,0 +1,32 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable

from torch.optim import Optimizer

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class IPUAccelerator(Accelerator):

    def setup_optimizers(self, trainer):
        super().setup_optimizers(trainer)

        if len(self.optimizers) > 1:
            raise MisconfigurationException("IPUs currently only support one optimizer.")

    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs):
        # The optimizer step itself is handled internally by the IPU; only run the closure here.
        lambda_closure()
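
For context on the optimizer_step override above, a rough sketch in plain PyTorch (not Lightning internals) of what the closure typically wraps: with poptorch the compiled program applies the optimizer update itself, so the accelerator only evaluates the closure and never calls optimizer.step().

import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = (torch.randn(8, 4), torch.randint(0, 2, (8,)))


def lambda_closure():
    # forward pass and loss computation wrapped up for the accelerator
    x, y = batch
    return F.cross_entropy(model(x), y)


# CPU/GPU-style step (for comparison): optimizer.step(lambda_closure)
# IPU-style step (as in IPUAccelerator.optimizer_step): just run the closure.
loss = lambda_closure()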
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/__init__.py
@@ -9,6 +9,7 @@
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401
    FullyShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
@@ -20,6 +21,7 @@
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ipu import IPUPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin # noqa: F401
@@ -41,6 +43,8 @@
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"HorovodPlugin",
"IPUPlugin",
"IPUPrecisionPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
"ShardedNativeMixedPrecisionPlugin",
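
With these registrations the new classes become importable from the top-level plugins package. A brief usage sketch, assuming the PR branch is installed; the IPUPrecisionPlugin constructor argument comes from the plugin definition later in this diff.

# Hypothetical usage once this PR is merged.
from pytorch_lightning.plugins import IPUPlugin, IPUPrecisionPlugin

precision_plugin = IPUPrecisionPlugin(precision=16)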
24 changes: 24 additions & 0 deletions pytorch_lightning/plugins/precision/ipu_precision.py
@@ -0,0 +1,24 @@
from typing import Any

from torch import Tensor

from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


class IPUPrecisionPlugin(PrecisionPlugin):
Contributor Author:
@awaelchli said that before the next boilerplate precision plugin we should refactor so the PrecisionPlugin lives inside the TrainingTypePlugin, but maybe we can allow one more if that refactor doesn't happen in time :P


    def __init__(self, precision: int) -> None:
        super().__init__()
        self.precision = precision

    def backward(
        self,
        closure_loss: Tensor,
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        # The IPU manages the backward step internally; return the loss unchanged.
        return closure_loss

    def clip_gradients(self, *args, **kwargs) -> None:
        pass
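
To illustrate why backward above returns the loss unchanged, a small standalone sketch in plain PyTorch contrasting host-side backward with the IPU behaviour (assumptions noted in the comments):

import torch

params = torch.nn.Parameter(torch.randn(3))
loss = (params ** 2).sum()

# Default precision-plugin behaviour (roughly): the backward pass runs on the host.
loss.backward()
assert params.grad is not None

# IPUPrecisionPlugin behaviour: the loss is returned as-is, because with poptorch the
# backward pass is part of the compiled program executed on the IPU, not on the host.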