diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 699f233ee2cc..1873560eed99 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -184,3 +184,8 @@ jobs: #train mkdir -p ~/.cache/torch/checkpoints/ && wget "https://download.pytorch.org/models/vgg16-397923af.pth" -O ~/.cache/torch/checkpoints/vgg16-397923af.pth python examples/fast_neural_style/neural_style.py train --epochs 1 --cuda 0 --dataset test --dataroot . --image_size 32 --style_image examples/fast_neural_style/images/style_images/mosaic.jpg --style_size 32 + - name: Run SR Example + if: ${{ matrix.os == 'ubuntu-latest' }} + run: | + # Super-Resolution + python examples/super_resolution/main.py --upscale_factor 3 --crop_size 180 --batch_size 4 --test_batch_size 100 --n_epochs 1 --lr 0.001 --threads 2 --debug diff --git a/examples/super_resolution/README.md b/examples/super_resolution/README.md index f9be6c92f563..d874747dc1cd 100644 --- a/examples/super_resolution/README.md +++ b/examples/super_resolution/README.md @@ -5,15 +5,16 @@ ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/sup This example illustrates how to use the efficient sub-pixel convolution layer described in ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" - Shi et al.](https://arxiv.org/abs/1609.05158) for increasing spatial resolution within your network for tasks such as superresolution. ``` -usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batch_size BATCHSIZE] +usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--crop_size CROPSIZE] [--batch_size BATCHSIZE] [--test_batch_size TESTBATCHSIZE] [--n_epochs NEPOCHS] [--lr LR] - [--cuda] [--threads THREADS] [--seed SEED] + [--cuda] [--threads THREADS] [--seed SEED] [--debug] PyTorch Super Res Example optional arguments: -h, --help show this help message and exit --upscale_factor super resolution upscale factor + --crop_size cropped size of the images for training --batch_size training batch size --test_batch_size testing batch size --n_epochs number of epochs to train for @@ -22,6 +23,7 @@ optional arguments: --mps enable GPU on macOS --threads number of threads for data loader to use Default=4 --seed random seed to use. Default=123 + --debug debug mode for testing ``` This example trains a super-resolution network on the [Caltech101 dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.Caltech101.html). A snapshot of the model after every epoch with filename `model_epoch_.pth` @@ -30,8 +32,20 @@ This example trains a super-resolution network on the [Caltech101 dataset](https ### Train -`python main.py --upscale_factor 3 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001` +`python main.py --upscale_factor 3 --crop_size 180 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001` -### Super Resolve +### Super-Resolve `python super_resolve.py --input_image .jpg --model model_epoch_500.pth --output_filename out.png` + +### Super-resolve example on a Cifar-10 image + +#### Input Image +![Cifar input image](./images/input_cifar.png) + +#### Output Images +| Output image from Model | Output from bicubic sampling | +|-------------------------------|------------------------------------| +| ![Cifar output image](./images/out_cifar.png) | ![Cifar output from bicubic sampling](./images/bicubic_image_cifar.png)| + + diff --git a/examples/super_resolution/images/bicubic_image_cifar.png b/examples/super_resolution/images/bicubic_image_cifar.png new file mode 100644 index 000000000000..b5bd4d9cf1b4 Binary files /dev/null and b/examples/super_resolution/images/bicubic_image_cifar.png differ diff --git a/examples/super_resolution/images/input_cifar.png b/examples/super_resolution/images/input_cifar.png new file mode 100644 index 000000000000..217b7e67d385 Binary files /dev/null and b/examples/super_resolution/images/input_cifar.png differ diff --git a/examples/super_resolution/images/out_cifar.png b/examples/super_resolution/images/out_cifar.png new file mode 100644 index 000000000000..9517aae801e2 Binary files /dev/null and b/examples/super_resolution/images/out_cifar.png differ diff --git a/examples/super_resolution/main.py b/examples/super_resolution/main.py index d46deec1701c..816d1caea7f2 100644 --- a/examples/super_resolution/main.py +++ b/examples/super_resolution/main.py @@ -8,11 +8,15 @@ from torch.utils.data import DataLoader from torchvision.transforms.functional import center_crop, resize, to_tensor +from ignite.contrib.handlers import ProgressBar + from ignite.engine import Engine, Events +from ignite.handlers import BasicTimeProfiler from ignite.metrics import PSNR # Training settings parser = argparse.ArgumentParser(description="PyTorch Super Res Example") +parser.add_argument("--crop_size", type=int, default=256, help="cropped size of the images for training") parser.add_argument("--upscale_factor", type=int, required=True, help="super resolution upscale factor") parser.add_argument("--batch_size", type=int, default=64, help="training batch size") parser.add_argument("--test_batch_size", type=int, default=10, help="testing batch size") @@ -22,6 +26,8 @@ parser.add_argument("--mps", action="store_true", default=False, help="enables macOS GPU training") parser.add_argument("--threads", type=int, default=4, help="number of threads for data loader to use") parser.add_argument("--seed", type=int, default=123, help="random seed to use. Default=123") +parser.add_argument("--debug", action="store_true", help="use debug") + opt = parser.parse_args() print(opt) @@ -70,8 +76,8 @@ def __len__(self): trainset = torchvision.datasets.Caltech101(root="./data", download=True) testset = torchvision.datasets.Caltech101(root="./data", download=False) -trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor) -testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor) +trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor, crop_size=opt.crop_size) +testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor, crop_size=opt.crop_size) training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True) testing_data_loader = DataLoader(dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size) @@ -109,35 +115,22 @@ def validation_step(engine, batch): psnr = PSNR(data_range=1) psnr.attach(evaluator, "psnr") validate_every = 1 -log_interval = 100 - -@trainer.on(Events.ITERATION_COMPLETED(every=log_interval)) -def log_training_loss(engine): - print( - "===> Epoch[{}]({}/{}): Loss: {:.4f}".format( - engine.state.epoch, engine.state.iteration, len(training_data_loader), engine.state.output - ) - ) +if opt.debug: + epoch_length = 10 + validate_epoch_length = 1 +else: + epoch_length = len(training_data_loader) + validate_epoch_length = len(testing_data_loader) @trainer.on(Events.EPOCH_COMPLETED(every=validate_every)) def log_validation(): - evaluator.run(testing_data_loader) + evaluator.run(testing_data_loader, epoch_length=validate_epoch_length) metrics = evaluator.state.metrics print(f"Epoch: {trainer.state.epoch}, Avg. PSNR: {metrics['psnr']} dB") -@trainer.on(Events.EPOCH_COMPLETED) -def log_epoch_time(): - print(f"Epoch {trainer.state.epoch}, Time Taken : {trainer.state.times['EPOCH_COMPLETED']}") - - -@trainer.on(Events.COMPLETED) -def log_total_time(): - print(f"Total Time: {trainer.state.times['COMPLETED']}") - - @trainer.on(Events.EPOCH_COMPLETED) def checkpoint(): model_out_path = "model_epoch_{}.pth".format(trainer.state.epoch) @@ -145,4 +138,13 @@ def checkpoint(): print("Checkpoint saved to {}".format(model_out_path)) -trainer.run(training_data_loader, opt.n_epochs) +# Attach basic profiler +basic_profiler = BasicTimeProfiler() +basic_profiler.attach(trainer) + +ProgressBar().attach(trainer, output_transform=lambda x: {"loss": x}) + +trainer.run(training_data_loader, opt.n_epochs, epoch_length=epoch_length) + +results = basic_profiler.get_results() +basic_profiler.print_results(results)