To run: python autoencoder.py --trainer.max_epochs=50
"""
+ from typing import Optional, Tuple

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pl_examples import _DATASETS_PATH, cli_lightning_logo
from pl_examples.basic_examples.mnist_datamodule import MNIST
+ from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE


if _TORCHVISION_AVAILABLE:
+     import torchvision
    from torchvision import transforms
+     from torchvision.utils import save_image
+
+
+ class ImageSampler(pl.callbacks.Callback):
+     def __init__(
+         self,
+         num_samples: int = 3,
+         nrow: int = 8,
+         padding: int = 2,
+         normalize: bool = True,
+         norm_range: Optional[Tuple[int, int]] = None,
+         scale_each: bool = False,
+         pad_value: int = 0,
+     ) -> None:
+         """
+         Args:
+             num_samples: Number of images displayed in the grid. Default: ``3``.
+             nrow: Number of images displayed in each row of the grid.
+                 The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
+             padding: Amount of padding. Default: ``2``.
+             normalize: If ``True``, shift the image to the range (0, 1),
+                 by the min and max values specified by ``norm_range``. Default: ``True``.
+             norm_range: Tuple (min, max) where min and max are numbers;
+                 these numbers are used to normalize the image. By default, min and max
+                 are computed from the tensor.
+             scale_each: If ``True``, scale each image in the batch of
+                 images separately rather than over the (min, max) of all images. Default: ``False``.
+             pad_value: Value for the padded pixels. Default: ``0``.
+         """
+         if not _TORCHVISION_AVAILABLE:  # pragma: no cover
+             raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")
+
+         super().__init__()
+         self.num_samples = num_samples
+         self.nrow = nrow
+         self.padding = padding
+         self.normalize = normalize
+         self.norm_range = norm_range
+         self.scale_each = scale_each
+         self.pad_value = pad_value
+
+     def _to_grid(self, images):
+         return torchvision.utils.make_grid(
+             tensor=images,
+             nrow=self.nrow,
+             padding=self.padding,
+             normalize=self.normalize,
+             range=self.norm_range,
+             scale_each=self.scale_each,
+             pad_value=self.pad_value,
+         )
+
+     @rank_zero_only
+     def on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+         if not _TORCHVISION_AVAILABLE:
+             return
+
+         images, _ = next(iter(DataLoader(trainer.datamodule.mnist_val, batch_size=self.num_samples)))
+         images_flattened = images.view(images.size(0), -1)
+
+         # generate images
+         with torch.no_grad():
+             pl_module.eval()
+             images_generated = pl_module(images_flattened.to(pl_module.device))
+             pl_module.train()
+
+         # save the original grid once, the reconstructed grid every epoch
+         if trainer.current_epoch == 0:
+             save_image(self._to_grid(images), f"grid_ori_{trainer.current_epoch}.png")
+         save_image(self._to_grid(images_generated.reshape(images.shape)), f"grid_generated_{trainer.current_epoch}.png")


class LitAutoEncoder(pl.LightningModule):
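Note on the new callback: it runs in `on_epoch_end` and pulls validation images through `trainer.datamodule.mnist_val`, so it assumes a datamodule exposing that attribute, and `@rank_zero_only` ensures the grids are written only once under distributed training. The PR registers it via `trainer_defaults` in `cli_main` below, but it can equally be attached to a hand-built Trainer. A minimal sketch, not part of the diff (assumes `MyDataModule` from this file provides `mnist_val`):

    sampler = ImageSampler(num_samples=8, nrow=4)
    trainer = pl.Trainer(max_epochs=5, callbacks=[sampler])
    trainer.fit(LitAutoEncoder(), datamodule=MyDataModule())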
@@ -46,44 +118,37 @@ def __init__(self, hidden_dim: int = 64):
        self.decoder = nn.Sequential(nn.Linear(3, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 28 * 28))

    def forward(self, x):
-         # in lightning, forward defines the prediction/inference actions
-         embedding = self.encoder(x)
-         return embedding
-
-     def training_step(self, batch, batch_idx):
-         x, y = batch
-         x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
-         loss = F.mse_loss(x_hat, x)
-         return loss
+         return x_hat
+
+     def training_step(self, batch, batch_idx):
+         return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
-         x, y = batch
-         x = x.view(x.size(0), -1)
-         z = self.encoder(x)
-         x_hat = self.decoder(z)
-         loss = F.mse_loss(x_hat, x)
-         self.log("valid_loss", loss, on_step=True)
+         self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
-         x, y = batch
-         x = x.view(x.size(0), -1)
-         z = self.encoder(x)
-         x_hat = self.decoder(z)
-         loss = F.mse_loss(x_hat, x)
-         self.log("test_loss", loss, on_step=True)
+         self._common_step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
-         x, y = batch
-         x = x.view(x.size(0), -1)
-         z = self.encoder(x)
-         return self.decoder(z)
+         x = self._prepare_batch(batch)
+         return self(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

+     def _prepare_batch(self, batch):
+         x, _ = batch
+         return x.view(x.size(0), -1)
+
+     def _common_step(self, batch, batch_idx, stage: str):
+         x = self._prepare_batch(batch)
+         loss = F.mse_loss(x, self(x))
+         self.log(f"{stage}_loss", loss, on_step=True)
+         return loss
+


class MyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32):
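Design note: the train, validation, and test loops previously repeated the same unpack-flatten-reconstruct-loss sequence; `_common_step` now owns it and logs under a stage-prefixed key ("train_loss", "val_loss", "test_loss"). The validation metric is thereby renamed from "valid_loss" to "val_loss", which matters for anything monitoring the old key. A minimal sketch of the refactored prediction path, illustrative only (the random tensor stands in for an MNIST batch):

    model = LitAutoEncoder()
    batch = (torch.rand(4, 1, 28, 28), torch.zeros(4))  # (images, dummy labels)
    reconstruction = model.predict_step(batch, batch_idx=0)  # shape: (4, 784)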
@@ -108,7 +173,12 @@ def predict_dataloader(self):

def cli_main():
    cli = LightningCLI(
-         LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
+         LitAutoEncoder,
+         MyDataModule,
+         seed_everything_default=1234,
+         save_config_overwrite=True,
+         run=False,  # used to deactivate automatic fitting
+         trainer_defaults={"callbacks": ImageSampler(), "max_epochs": 10},
    )
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    cli.trainer.test(ckpt_path="best")
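With `run=False`, the LightningCLI only parses arguments and instantiates the model, datamodule, and trainer; fitting and testing are then invoked explicitly above. The `trainer_defaults` are ordinary defaults, so they can still be overridden from the command line as in the module docstring, e.g. (the `--model.*` and `--data.*` flags assume LightningCLI's convention of exposing `__init__` arguments):

    python autoencoder.py --trainer.max_epochs=50 --model.hidden_dim=128 --data.batch_size=64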