Skip to content

Commit aa81057

Browse files
authored
Add CI for Super Resolution example and tqdm bar to the example (#2899)
* Add the example for Super-Resolution * Made some changes * Made some changes * Add the time profiling features * Added torchvision dataset * Changed the dataset used in README to cifar10 * Used snake case in arguments * Made some changes * Make some formatting changes * Make the formatting changes * some changes * update the crop method * Made the suggested changes * Add SR example to unit tests * Add tqdm to the SR example and some CI changes * Update unit-tests.yml * Update unit-tests.yml * changed crop_size in SR example * Made crop_size a parameter in SR example * Add debug mode in SR example * Added Cifar image example * autopep8 fix * Some reformatting of files * Added Basic Profile Handler in SR example * made some changes * Update README * Update README.md --------- Co-authored-by: guptaaryan16 <[email protected]>
1 parent b48825b commit aa81057

File tree

6 files changed

+48
-27
lines changed

6 files changed

+48
-27
lines changed

.github/workflows/unit-tests.yml

+5
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,8 @@ jobs:
184184
#train
185185
mkdir -p ~/.cache/torch/checkpoints/ && wget "https://download.pytorch.org/models/vgg16-397923af.pth" -O ~/.cache/torch/checkpoints/vgg16-397923af.pth
186186
python examples/fast_neural_style/neural_style.py train --epochs 1 --cuda 0 --dataset test --dataroot . --image_size 32 --style_image examples/fast_neural_style/images/style_images/mosaic.jpg --style_size 32
187+
- name: Run SR Example
188+
if: ${{ matrix.os == 'ubuntu-latest' }}
189+
run: |
190+
# Super-Resolution
191+
python examples/super_resolution/main.py --upscale_factor 3 --crop_size 180 --batch_size 4 --test_batch_size 100 --n_epochs 1 --lr 0.001 --threads 2 --debug

examples/super_resolution/README.md

+18-4
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@ ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/sup
55
This example illustrates how to use the efficient sub-pixel convolution layer described in ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" - Shi et al.](https://arxiv.org/abs/1609.05158) for increasing spatial resolution within your network for tasks such as superresolution.
66

77
```
8-
usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batch_size BATCHSIZE]
8+
usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--crop_size CROPSIZE] [--batch_size BATCHSIZE]
99
[--test_batch_size TESTBATCHSIZE] [--n_epochs NEPOCHS] [--lr LR]
10-
[--cuda] [--threads THREADS] [--seed SEED]
10+
[--cuda] [--threads THREADS] [--seed SEED] [--debug]
1111
1212
PyTorch Super Res Example
1313
1414
optional arguments:
1515
-h, --help show this help message and exit
1616
--upscale_factor super resolution upscale factor
17+
--crop_size cropped size of the images for training
1718
--batch_size training batch size
1819
--test_batch_size testing batch size
1920
--n_epochs number of epochs to train for
@@ -22,6 +23,7 @@ optional arguments:
2223
--mps enable GPU on macOS
2324
--threads number of threads for data loader to use Default=4
2425
--seed random seed to use. Default=123
26+
--debug debug mode for testing
2527
```
2628

2729
This example trains a super-resolution network on the [Caltech101 dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.Caltech101.html). A snapshot of the model after every epoch with filename `model_epoch_<epoch_number>.pth`
@@ -30,8 +32,20 @@ This example trains a super-resolution network on the [Caltech101 dataset](https
3032

3133
### Train
3234

33-
`python main.py --upscale_factor 3 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001`
35+
`python main.py --upscale_factor 3 --crop_size 180 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001`
3436

35-
### Super Resolve
37+
### Super-Resolve
3638

3739
`python super_resolve.py --input_image <in>.jpg --model model_epoch_500.pth --output_filename out.png`
40+
41+
### Super-resolve example on a Cifar-10 image
42+
43+
#### Input Image
44+
![Cifar input image](./images/input_cifar.png)
45+
46+
#### Output Images
47+
| Output image from Model | Output from bicubic sampling |
48+
|-------------------------------|------------------------------------|
49+
| ![Cifar output image](./images/out_cifar.png) | ![Cifar output from bicubic sampling](./images/bicubic_image_cifar.png)|
50+
51+
Loading
2.03 KB
Loading
13.6 KB
Loading

examples/super_resolution/main.py

+25-23
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
from torch.utils.data import DataLoader
99
from torchvision.transforms.functional import center_crop, resize, to_tensor
1010

11+
from ignite.contrib.handlers import ProgressBar
12+
1113
from ignite.engine import Engine, Events
14+
from ignite.handlers import BasicTimeProfiler
1215
from ignite.metrics import PSNR
1316

1417
# Training settings
1518
parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
19+
parser.add_argument("--crop_size", type=int, default=256, help="cropped size of the images for training")
1620
parser.add_argument("--upscale_factor", type=int, required=True, help="super resolution upscale factor")
1721
parser.add_argument("--batch_size", type=int, default=64, help="training batch size")
1822
parser.add_argument("--test_batch_size", type=int, default=10, help="testing batch size")
@@ -22,6 +26,8 @@
2226
parser.add_argument("--mps", action="store_true", default=False, help="enables macOS GPU training")
2327
parser.add_argument("--threads", type=int, default=4, help="number of threads for data loader to use")
2428
parser.add_argument("--seed", type=int, default=123, help="random seed to use. Default=123")
29+
parser.add_argument("--debug", action="store_true", help="use debug")
30+
2531
opt = parser.parse_args()
2632

2733
print(opt)
@@ -70,8 +76,8 @@ def __len__(self):
7076
trainset = torchvision.datasets.Caltech101(root="./data", download=True)
7177
testset = torchvision.datasets.Caltech101(root="./data", download=False)
7278

73-
trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor)
74-
testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor)
79+
trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor, crop_size=opt.crop_size)
80+
testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor, crop_size=opt.crop_size)
7581

7682
training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
7783
testing_data_loader = DataLoader(dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size)
@@ -109,40 +115,36 @@ def validation_step(engine, batch):
109115
psnr = PSNR(data_range=1)
110116
psnr.attach(evaluator, "psnr")
111117
validate_every = 1
112-
log_interval = 100
113-
114118

115-
@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
116-
def log_training_loss(engine):
117-
print(
118-
"===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
119-
engine.state.epoch, engine.state.iteration, len(training_data_loader), engine.state.output
120-
)
121-
)
119+
if opt.debug:
120+
epoch_length = 10
121+
validate_epoch_length = 1
122+
else:
123+
epoch_length = len(training_data_loader)
124+
validate_epoch_length = len(testing_data_loader)
122125

123126

124127
@trainer.on(Events.EPOCH_COMPLETED(every=validate_every))
125128
def log_validation():
126-
evaluator.run(testing_data_loader)
129+
evaluator.run(testing_data_loader, epoch_length=validate_epoch_length)
127130
metrics = evaluator.state.metrics
128131
print(f"Epoch: {trainer.state.epoch}, Avg. PSNR: {metrics['psnr']} dB")
129132

130133

131-
@trainer.on(Events.EPOCH_COMPLETED)
132-
def log_epoch_time():
133-
print(f"Epoch {trainer.state.epoch}, Time Taken : {trainer.state.times['EPOCH_COMPLETED']}")
134-
135-
136-
@trainer.on(Events.COMPLETED)
137-
def log_total_time():
138-
print(f"Total Time: {trainer.state.times['COMPLETED']}")
139-
140-
141134
@trainer.on(Events.EPOCH_COMPLETED)
142135
def checkpoint():
143136
model_out_path = "model_epoch_{}.pth".format(trainer.state.epoch)
144137
torch.save(model, model_out_path)
145138
print("Checkpoint saved to {}".format(model_out_path))
146139

147140

148-
trainer.run(training_data_loader, opt.n_epochs)
141+
# Attach basic profiler
142+
basic_profiler = BasicTimeProfiler()
143+
basic_profiler.attach(trainer)
144+
145+
ProgressBar().attach(trainer, output_transform=lambda x: {"loss": x})
146+
147+
trainer.run(training_data_loader, opt.n_epochs, epoch_length=epoch_length)
148+
149+
results = basic_profiler.get_results()
150+
basic_profiler.print_results(results)

0 commit comments

Comments
 (0)