|
8 | 8 | from torch.utils.data import DataLoader
|
9 | 9 | from torchvision.transforms.functional import center_crop, resize, to_tensor
|
10 | 10 |
|
| 11 | +from ignite.contrib.handlers import ProgressBar |
| 12 | + |
11 | 13 | from ignite.engine import Engine, Events
|
| 14 | +from ignite.handlers import BasicTimeProfiler |
12 | 15 | from ignite.metrics import PSNR
|
13 | 16 |
|
14 | 17 | # Training settings
|
15 | 18 | parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
|
| 19 | +parser.add_argument("--crop_size", type=int, default=256, help="cropped size of the images for training") |
16 | 20 | parser.add_argument("--upscale_factor", type=int, required=True, help="super resolution upscale factor")
|
17 | 21 | parser.add_argument("--batch_size", type=int, default=64, help="training batch size")
|
18 | 22 | parser.add_argument("--test_batch_size", type=int, default=10, help="testing batch size")
|
|
22 | 26 | parser.add_argument("--mps", action="store_true", default=False, help="enables macOS GPU training")
|
23 | 27 | parser.add_argument("--threads", type=int, default=4, help="number of threads for data loader to use")
|
24 | 28 | parser.add_argument("--seed", type=int, default=123, help="random seed to use. Default=123")
|
| 29 | +parser.add_argument("--debug", action="store_true", help="use debug") |
| 30 | + |
25 | 31 | opt = parser.parse_args()
|
26 | 32 |
|
27 | 33 | print(opt)
|
@@ -70,8 +76,8 @@ def __len__(self):
|
70 | 76 | trainset = torchvision.datasets.Caltech101(root="./data", download=True)
|
71 | 77 | testset = torchvision.datasets.Caltech101(root="./data", download=False)
|
72 | 78 |
|
73 |
| -trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor) |
74 |
| -testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor) |
| 79 | +trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor, crop_size=opt.crop_size) |
| 80 | +testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor, crop_size=opt.crop_size) |
75 | 81 |
|
76 | 82 | training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
|
77 | 83 | testing_data_loader = DataLoader(dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size)
|
@@ -109,40 +115,36 @@ def validation_step(engine, batch):
|
109 | 115 | psnr = PSNR(data_range=1)
|
110 | 116 | psnr.attach(evaluator, "psnr")
|
111 | 117 | validate_every = 1
|
112 |
| -log_interval = 100 |
113 |
| - |
114 | 118 |
|
115 |
| -@trainer.on(Events.ITERATION_COMPLETED(every=log_interval)) |
116 |
| -def log_training_loss(engine): |
117 |
| - print( |
118 |
| - "===> Epoch[{}]({}/{}): Loss: {:.4f}".format( |
119 |
| - engine.state.epoch, engine.state.iteration, len(training_data_loader), engine.state.output |
120 |
| - ) |
121 |
| - ) |
| 119 | +if opt.debug: |
| 120 | + epoch_length = 10 |
| 121 | + validate_epoch_length = 1 |
| 122 | +else: |
| 123 | + epoch_length = len(training_data_loader) |
| 124 | + validate_epoch_length = len(testing_data_loader) |
122 | 125 |
|
123 | 126 |
|
124 | 127 | @trainer.on(Events.EPOCH_COMPLETED(every=validate_every))
|
125 | 128 | def log_validation():
|
126 |
| - evaluator.run(testing_data_loader) |
| 129 | + evaluator.run(testing_data_loader, epoch_length=validate_epoch_length) |
127 | 130 | metrics = evaluator.state.metrics
|
128 | 131 | print(f"Epoch: {trainer.state.epoch}, Avg. PSNR: {metrics['psnr']} dB")
|
129 | 132 |
|
130 | 133 |
|
131 |
| -@trainer.on(Events.EPOCH_COMPLETED) |
132 |
| -def log_epoch_time(): |
133 |
| - print(f"Epoch {trainer.state.epoch}, Time Taken : {trainer.state.times['EPOCH_COMPLETED']}") |
134 |
| - |
135 |
| - |
136 |
| -@trainer.on(Events.COMPLETED) |
137 |
| -def log_total_time(): |
138 |
| - print(f"Total Time: {trainer.state.times['COMPLETED']}") |
139 |
| - |
140 |
| - |
141 | 134 | @trainer.on(Events.EPOCH_COMPLETED)
|
142 | 135 | def checkpoint():
|
143 | 136 | model_out_path = "model_epoch_{}.pth".format(trainer.state.epoch)
|
144 | 137 | torch.save(model, model_out_path)
|
145 | 138 | print("Checkpoint saved to {}".format(model_out_path))
|
146 | 139 |
|
147 | 140 |
|
148 |
| -trainer.run(training_data_loader, opt.n_epochs) |
| 141 | +# Attach basic profiler |
| 142 | +basic_profiler = BasicTimeProfiler() |
| 143 | +basic_profiler.attach(trainer) |
| 144 | + |
| 145 | +ProgressBar().attach(trainer, output_transform=lambda x: {"loss": x}) |
| 146 | + |
| 147 | +trainer.run(training_data_loader, opt.n_epochs, epoch_length=epoch_length) |
| 148 | + |
| 149 | +results = basic_profiler.get_results() |
| 150 | +basic_profiler.print_results(results) |
0 commit comments