Add CI for Super Resolution example and tqdm bar to the example #2899


Merged
38 commits · merged on Mar 29, 2023

Commits
a7829a9
Add the example for Super-Resolution
guptaaryan16 Mar 3, 2023
74602d4
Merge branch 'master' of github.com:guptaaryan16/ignite
guptaaryan16 Mar 3, 2023
1b0baf3
Made some changes
guptaaryan16 Mar 3, 2023
7ebee49
Made some changes
guptaaryan16 Mar 6, 2023
f6b5b41
Merge branch 'pytorch:master' into master
guptaaryan16 Mar 14, 2023
d810510
Merge branch 'pytorch:master' into master
guptaaryan16 Mar 15, 2023
3982d7b
Add the time profiling features
guptaaryan16 Mar 15, 2023
bc219c7
Merge branch 'pytorch:master' into master
guptaaryan16 Mar 17, 2023
982a0eb
Added torchvision dataset
guptaaryan16 Mar 17, 2023
51fe3df
Merge branch 'master' of github.com:guptaaryan16/ignite
guptaaryan16 Mar 17, 2023
0cd5c59
Changed the dataset used in README to cifar10
guptaaryan16 Mar 17, 2023
83f10e2
Merge branch 'pytorch:master' into master
guptaaryan16 Mar 20, 2023
7bcea2f
Used snake case in arguments
guptaaryan16 Mar 20, 2023
698d76f
Made some changes
guptaaryan16 Mar 20, 2023
51f47b4
Make some formatting changes
guptaaryan16 Mar 20, 2023
235c908
Make the formatting changes
guptaaryan16 Mar 20, 2023
3b2fde9
some changes
guptaaryan16 Mar 20, 2023
0e2f9a3
update the crop method
guptaaryan16 Mar 21, 2023
3d9dda7
Made the suggested changes
guptaaryan16 Mar 21, 2023
a91912b
Merge branch 'pytorch:master' into master
guptaaryan16 Mar 22, 2023
689b7e4
Add SR example to unit tests
guptaaryan16 Mar 22, 2023
3303d86
Merge branch 'master' of github.com:guptaaryan16/ignite
guptaaryan16 Mar 22, 2023
fb3f64a
Add tqdm to the SR example and some CI changes
guptaaryan16 Mar 22, 2023
051999e
Update unit-tests.yml
guptaaryan16 Mar 22, 2023
e36beff
Update unit-tests.yml
guptaaryan16 Mar 22, 2023
87456cd
changed crop_size in SR example
guptaaryan16 Mar 22, 2023
780dbdb
Made crop_size a parameter in SR example
guptaaryan16 Mar 22, 2023
b69c914
Add debug mode in SR example
guptaaryan16 Mar 24, 2023
4b1d337
Added Cifar image example
guptaaryan16 Mar 24, 2023
93766f7
autopep8 fix
guptaaryan16 Mar 24, 2023
93d1584
Some reformatting of files
guptaaryan16 Mar 24, 2023
8541b2c
Merge branch 'master' of github.com:guptaaryan16/ignite
guptaaryan16 Mar 24, 2023
655e569
Added Basic Profile Handler in SR example
guptaaryan16 Mar 27, 2023
2ce8749
made some changes
guptaaryan16 Mar 27, 2023
52b3043
Merge branch 'pytorch:master' into master
guptaaryan16 Mar 27, 2023
b53b150
Merge branch 'master' of github.com:guptaaryan16/ignite
guptaaryan16 Mar 27, 2023
9f81e33
Update README
guptaaryan16 Mar 28, 2023
5ccd25d
Update README.md
guptaaryan16 Mar 29, 2023
5 changes: 5 additions & 0 deletions .github/workflows/unit-tests.yml
@@ -184,3 +184,8 @@ jobs:
#train
mkdir -p ~/.cache/torch/checkpoints/ && wget "https://download.pytorch.org/models/vgg16-397923af.pth" -O ~/.cache/torch/checkpoints/vgg16-397923af.pth
python examples/fast_neural_style/neural_style.py train --epochs 1 --cuda 0 --dataset test --dataroot . --image_size 32 --style_image examples/fast_neural_style/images/style_images/mosaic.jpg --style_size 32
- name: Run SR Example
if: ${{ matrix.os == 'ubuntu-latest' }}
run: |
# Super-Resolution
python examples/super_resolution/main.py --upscale_factor 3 --crop_size 180 --batch_size 4 --test_batch_size 100 --n_epochs 1 --lr 0.001 --threads 2 --debug
22 changes: 18 additions & 4 deletions examples/super_resolution/README.md
@@ -5,15 +5,16 @@ ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/sup
This example illustrates how to use the efficient sub-pixel convolution layer described in ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" - Shi et al.](https://arxiv.org/abs/1609.05158) for increasing spatial resolution within your network for tasks such as super-resolution.

```
usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batch_size BATCHSIZE]
usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--crop_size CROPSIZE] [--batch_size BATCHSIZE]
[--test_batch_size TESTBATCHSIZE] [--n_epochs NEPOCHS] [--lr LR]
[--cuda] [--threads THREADS] [--seed SEED]
[--cuda] [--threads THREADS] [--seed SEED] [--debug]

PyTorch Super Res Example

optional arguments:
-h, --help show this help message and exit
--upscale_factor super resolution upscale factor
--crop_size cropped size of the images for training
--batch_size training batch size
--test_batch_size testing batch size
--n_epochs number of epochs to train for
@@ -22,6 +23,7 @@ optional arguments:
--mps enable GPU on macOS
--threads number of threads for data loader to use Default=4
--seed random seed to use. Default=123
--debug debug mode for testing
```

This example trains a super-resolution network on the [Caltech101 dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.Caltech101.html). A snapshot of the model after every epoch with filename `model_epoch_<epoch_number>.pth`
@@ -30,8 +32,20 @@ This example trains a super-resolution network on the [Caltech101 dataset](https

### Train

`python main.py --upscale_factor 3 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001`
`python main.py --upscale_factor 3 --crop_size 180 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001`

### Super Resolve
### Super-Resolve

`python super_resolve.py --input_image <in>.jpg --model model_epoch_500.pth --output_filename out.png`

### Super-resolve example on a Cifar-10 image

#### Input Image
![Cifar input image](./images/input_cifar.png)

#### Output Images
| Output image from Model | Output from bicubic sampling |
|-------------------------------|------------------------------------|
| ![Cifar output image](./images/out_cifar.png) | ![Cifar output from bicubic sampling](./images/bicubic_image_cifar.png)|
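For context on the right-hand column: "bicubic sampling" is a plain bicubic upscale of the input with no learned model. A minimal sketch of how such a baseline could be produced (the file names and the upscale factor of 3 are assumptions for illustration, not taken from this PR):

```
# Hypothetical baseline script, not part of this PR: upscale the input image
# with plain bicubic interpolation for comparison against the model output.
from PIL import Image
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize

upscale_factor = 3  # assumed; should match the factor the model was trained with
img = Image.open("images/input_cifar.png")

bicubic = resize(
    img,
    [img.height * upscale_factor, img.width * upscale_factor],
    interpolation=InterpolationMode.BICUBIC,
)
bicubic.save("images/bicubic_image_cifar.png")
```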


Binary file added examples/super_resolution/images/bicubic_image_cifar.png
Binary file added examples/super_resolution/images/input_cifar.png
Binary file added examples/super_resolution/images/out_cifar.png
48 changes: 25 additions & 23 deletions examples/super_resolution/main.py
@@ -8,11 +8,15 @@
from torch.utils.data import DataLoader
from torchvision.transforms.functional import center_crop, resize, to_tensor

from ignite.contrib.handlers import ProgressBar

from ignite.engine import Engine, Events
from ignite.handlers import BasicTimeProfiler
from ignite.metrics import PSNR

# Training settings
parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
parser.add_argument("--crop_size", type=int, default=256, help="cropped size of the images for training")
parser.add_argument("--upscale_factor", type=int, required=True, help="super resolution upscale factor")
parser.add_argument("--batch_size", type=int, default=64, help="training batch size")
parser.add_argument("--test_batch_size", type=int, default=10, help="testing batch size")
@@ -22,6 +26,8 @@
parser.add_argument("--mps", action="store_true", default=False, help="enables macOS GPU training")
parser.add_argument("--threads", type=int, default=4, help="number of threads for data loader to use")
parser.add_argument("--seed", type=int, default=123, help="random seed to use. Default=123")
parser.add_argument("--debug", action="store_true", help="use debug")

opt = parser.parse_args()

print(opt)
@@ -70,8 +76,8 @@ def __len__(self):
trainset = torchvision.datasets.Caltech101(root="./data", download=True)
testset = torchvision.datasets.Caltech101(root="./data", download=False)

trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor)
testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor)
trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor, crop_size=opt.crop_size)
testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor, crop_size=opt.crop_size)

training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
testing_data_loader = DataLoader(dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size)
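The body of the SRDataset wrapper is collapsed in this diff; only its constructor call and the `__len__` hunk are visible. A rough sketch of what such a wrapper could look like, consistent with the `crop_size`/`scale_factor` arguments and the `center_crop`/`resize`/`to_tensor` imports above (an illustration under assumptions, not the PR's exact implementation):

```
from torch.utils.data import Dataset
from torchvision.transforms.functional import center_crop, resize, to_tensor


class SRDataset(Dataset):
    """Wraps an image dataset and yields (low-res, high-res) pairs of tensors."""

    def __init__(self, dataset, scale_factor, crop_size=256):
        self.dataset = dataset
        self.scale_factor = scale_factor
        self.crop_size = crop_size

    def __getitem__(self, index):
        image, _ = self.dataset[index]  # torchvision datasets return (image, target)
        hr_image, _, _ = image.convert("YCbCr").split()  # luminance channel only (assumed)
        hr_image = center_crop(hr_image, self.crop_size)
        lr_size = self.crop_size // self.scale_factor
        lr_image = resize(hr_image, [lr_size, lr_size])  # downsample to build the LR input
        return to_tensor(lr_image), to_tensor(hr_image)

    def __len__(self):
        return len(self.dataset)
```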
@@ -109,40 +115,36 @@ def validation_step(engine, batch):
psnr = PSNR(data_range=1)
psnr.attach(evaluator, "psnr")
validate_every = 1
log_interval = 100


@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(engine):
print(
"===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
engine.state.epoch, engine.state.iteration, len(training_data_loader), engine.state.output
)
)
if opt.debug:
epoch_length = 10
validate_epoch_length = 1
else:
epoch_length = len(training_data_loader)
validate_epoch_length = len(testing_data_loader)


@trainer.on(Events.EPOCH_COMPLETED(every=validate_every))
def log_validation():
evaluator.run(testing_data_loader)
evaluator.run(testing_data_loader, epoch_length=validate_epoch_length)
metrics = evaluator.state.metrics
print(f"Epoch: {trainer.state.epoch}, Avg. PSNR: {metrics['psnr']} dB")


@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch_time():
print(f"Epoch {trainer.state.epoch}, Time Taken : {trainer.state.times['EPOCH_COMPLETED']}")


@trainer.on(Events.COMPLETED)
def log_total_time():
print(f"Total Time: {trainer.state.times['COMPLETED']}")


@trainer.on(Events.EPOCH_COMPLETED)
def checkpoint():
model_out_path = "model_epoch_{}.pth".format(trainer.state.epoch)
torch.save(model, model_out_path)
print("Checkpoint saved to {}".format(model_out_path))


trainer.run(training_data_loader, opt.n_epochs)
# Attach basic profiler
basic_profiler = BasicTimeProfiler()
basic_profiler.attach(trainer)

ProgressBar().attach(trainer, output_transform=lambda x: {"loss": x})

trainer.run(training_data_loader, opt.n_epochs, epoch_length=epoch_length)

results = basic_profiler.get_results()
basic_profiler.print_results(results)
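The profiler and progress bar added above are generic Ignite handlers; below is a minimal, self-contained sketch of the same wiring on a dummy engine (it assumes ignite and tqdm are installed; the step function and data are placeholders, not part of this PR):

```
import time

from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine
from ignite.handlers import BasicTimeProfiler


def train_step(engine, batch):
    time.sleep(0.01)  # placeholder for the real forward/backward pass
    return 0.0  # surfaced as "loss" by the progress bar below


trainer = Engine(train_step)

# Attach the profiler before running, as in the example above
profiler = BasicTimeProfiler()
profiler.attach(trainer)

# tqdm progress bar showing the step output as a running "loss"
ProgressBar().attach(trainer, output_transform=lambda x: {"loss": x})

trainer.run(range(100), max_epochs=2)

# Aggregate and print per-event timing statistics
results = profiler.get_results()
profiler.print_results(results)
```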