Skip to content

Commit ffd275f

Browse files
authored
[Refactor] Improve auto-encoder example (#9402)
1 parent 81687aa commit ffd275f

File tree

1 file changed

+96
-26
lines changed

1 file changed

+96
-26
lines changed

pl_examples/basic_examples/autoencoder.py

Lines changed: 96 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
1616
To run: python autoencoder.py --trainer.max_epochs=50
1717
"""
18+
from typing import Optional, Tuple
1819

1920
import torch
2021
import torch.nn.functional as F
@@ -24,11 +25,82 @@
2425
import pytorch_lightning as pl
2526
from pl_examples import _DATASETS_PATH, cli_lightning_logo
2627
from pl_examples.basic_examples.mnist_datamodule import MNIST
28+
from pytorch_lightning.utilities import rank_zero_only
2729
from pytorch_lightning.utilities.cli import LightningCLI
2830
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
2931

3032
if _TORCHVISION_AVAILABLE:
33+
import torchvision
3134
from torchvision import transforms
35+
from torchvision.utils import save_image
36+
37+
38+
class ImageSampler(pl.callbacks.Callback):
39+
def __init__(
40+
self,
41+
num_samples: int = 3,
42+
nrow: int = 8,
43+
padding: int = 2,
44+
normalize: bool = True,
45+
norm_range: Optional[Tuple[int, int]] = None,
46+
scale_each: bool = False,
47+
pad_value: int = 0,
48+
) -> None:
49+
"""
50+
Args:
51+
num_samples: Number of images displayed in the grid. Default: ``3``.
52+
nrow: Number of images displayed in each row of the grid.
53+
The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
54+
padding: Amount of padding. Default: ``2``.
55+
normalize: If ``True``, shift the image to the range (0, 1),
56+
by the min and max values specified by :attr:`range`. Default: ``False``.
57+
norm_range: Tuple (min, max) where min and max are numbers,
58+
then these numbers are used to normalize the image. By default, min and max
59+
are computed from the tensor.
60+
scale_each: If ``True``, scale each image in the batch of
61+
images separately rather than the (min, max) over all images. Default: ``False``.
62+
pad_value: Value for the padded pixels. Default: ``0``.
63+
"""
64+
if not _TORCHVISION_AVAILABLE: # pragma: no cover
65+
raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")
66+
67+
super().__init__()
68+
self.num_samples = num_samples
69+
self.nrow = nrow
70+
self.padding = padding
71+
self.normalize = normalize
72+
self.norm_range = norm_range
73+
self.scale_each = scale_each
74+
self.pad_value = pad_value
75+
76+
def _to_grid(self, images):
77+
return torchvision.utils.make_grid(
78+
tensor=images,
79+
nrow=self.nrow,
80+
padding=self.padding,
81+
normalize=self.normalize,
82+
range=self.norm_range,
83+
scale_each=self.scale_each,
84+
pad_value=self.pad_value,
85+
)
86+
87+
@rank_zero_only
88+
def on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
89+
if not _TORCHVISION_AVAILABLE:
90+
return
91+
92+
images, _ = next(iter(DataLoader(trainer.datamodule.mnist_val, batch_size=self.num_samples)))
93+
images_flattened = images.view(images.size(0), -1)
94+
95+
# generate images
96+
with torch.no_grad():
97+
pl_module.eval()
98+
images_generated = pl_module(images_flattened.to(pl_module.device))
99+
pl_module.train()
100+
101+
if trainer.current_epoch == 0:
102+
save_image(self._to_grid(images), f"grid_ori_{trainer.current_epoch}.png")
103+
save_image(self._to_grid(images_generated.reshape(images.shape)), f"grid_generated_{trainer.current_epoch}.png")
32104

33105

34106
class LitAutoEncoder(pl.LightningModule):
@@ -46,44 +118,37 @@ def __init__(self, hidden_dim: int = 64):
46118
self.decoder = nn.Sequential(nn.Linear(3, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 28 * 28))
47119

48120
def forward(self, x):
49-
# in lightning, forward defines the prediction/inference actions
50-
embedding = self.encoder(x)
51-
return embedding
52-
53-
def training_step(self, batch, batch_idx):
54-
x, y = batch
55-
x = x.view(x.size(0), -1)
56121
z = self.encoder(x)
57122
x_hat = self.decoder(z)
58-
loss = F.mse_loss(x_hat, x)
59-
return loss
123+
return x_hat
124+
125+
def training_step(self, batch, batch_idx):
126+
return self._common_step(batch, batch_idx, "train")
60127

61128
def validation_step(self, batch, batch_idx):
62-
x, y = batch
63-
x = x.view(x.size(0), -1)
64-
z = self.encoder(x)
65-
x_hat = self.decoder(z)
66-
loss = F.mse_loss(x_hat, x)
67-
self.log("valid_loss", loss, on_step=True)
129+
self._common_step(batch, batch_idx, "val")
68130

69131
def test_step(self, batch, batch_idx):
70-
x, y = batch
71-
x = x.view(x.size(0), -1)
72-
z = self.encoder(x)
73-
x_hat = self.decoder(z)
74-
loss = F.mse_loss(x_hat, x)
75-
self.log("test_loss", loss, on_step=True)
132+
self._common_step(batch, batch_idx, "test")
76133

77134
def predict_step(self, batch, batch_idx, dataloader_idx=None):
78-
x, y = batch
79-
x = x.view(x.size(0), -1)
80-
z = self.encoder(x)
81-
return self.decoder(z)
135+
x = self._prepare_batch(batch)
136+
return self(x)
82137

83138
def configure_optimizers(self):
84139
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
85140
return optimizer
86141

142+
def _prepare_batch(self, batch):
143+
x, _ = batch
144+
return x.view(x.size(0), -1)
145+
146+
def _common_step(self, batch, batch_idx, stage: str):
147+
x = self._prepare_batch(batch)
148+
loss = F.mse_loss(x, self(x))
149+
self.log(f"{stage}_loss", loss, on_step=True)
150+
return loss
151+
87152

88153
class MyDataModule(pl.LightningDataModule):
89154
def __init__(self, batch_size: int = 32):
@@ -108,7 +173,12 @@ def predict_dataloader(self):
108173

109174
def cli_main():
110175
cli = LightningCLI(
111-
LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
176+
LitAutoEncoder,
177+
MyDataModule,
178+
seed_everything_default=1234,
179+
save_config_overwrite=True,
180+
run=False, # used to de-activate automatic fitting.
181+
trainer_defaults={"callbacks": ImageSampler(), "max_epochs": 10},
112182
)
113183
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
114184
cli.trainer.test(ckpt_path="best")

0 commit comments

Comments
 (0)