
Commit 46e53e7

Add Efficient SR challenge code
1 parent 88c7cb9 commit 46e53e7

File tree: 1 file changed, +164 -0 lines changed


Diff for: main_challenge_sr.py

@@ -0,0 +1,164 @@
import os.path
import logging
import time
from collections import OrderedDict
import torch

from utils import utils_logger
from utils import utils_image as util

'''
This code can help you calculate:
`FLOPs`, `#Params`, `Runtime`, `#Activations`, `#Conv2d`, and `Max Memory Allocated`.

For more information, please refer to the ECCVW paper "AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results".

# If you use this code, please consider the following citations:

@inproceedings{zhang2020aim,
  title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results},
  author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others},
  booktitle={European Conference on Computer Vision Workshops},
  year={2020}
}
@inproceedings{zhang2019aim,
  title={AIM 2019 Challenge on Constrained Super-Resolution: Methods and Results},
  author={Kai Zhang and Shuhang Gu and Radu Timofte and others},
  booktitle={IEEE International Conference on Computer Vision Workshops},
  year={2019}
}

CuDNN (https://developer.nvidia.com/rdp/cudnn-archive) should be installed.

For `Max Memory` and `Runtime`, set 'print_modelsummary = False' and 'save_results = False'.
'''
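
# A minimal usage sketch (added note, assuming the folder layout implied by the
# settings in main() below): place the low-resolution test images under
# `testsets/set12` or `testsets/DIV2K_valid_LR`, put the pretrained weights
# (`msrresnet_x4_psnr.pth`, `imdn_x4.pth`) under `model_zoo/`, and then run
#
#     python main_challenge_sr.py
#
# When `save_results` is True, the super-resolved images are written to
# `testsets/<testset_L>_<model_name>`.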


def main():

    utils_logger.logger_info('efficientsr_challenge', log_path='efficientsr_challenge.log')
    logger = logging.getLogger('efficientsr_challenge')

    # print(torch.__version__)               # pytorch version
    # print(torch.version.cuda)              # cuda version
    # print(torch.backends.cudnn.version())  # cudnn version

    # --------------------------------
    # basic settings
    # --------------------------------
    model_names = ['msrresnet', 'imdn']
    model_id = 1                  # set the model name
    model_name = model_names[model_id]
    logger.info('{:>16s} : {:s}'.format('Model Name', model_name))

    testsets = 'testsets'         # set path of testsets
    testset_L = 'DIV2K_valid_LR'  # set current testing dataset; 'DIV2K_test_LR'
    testset_L = 'set12'

    save_results = True
    print_modelsummary = True     # set False when calculating `Max Memory` and `Runtime`

    torch.cuda.set_device(0)      # set GPU ID
    logger.info('{:>16s} : {:<d}'.format('GPU ID', torch.cuda.current_device()))
    torch.cuda.empty_cache()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --------------------------------
    # define network and load model
    # --------------------------------
    if model_name == 'msrresnet':
        from models.network_msrresnet import MSRResNet1 as net
        model = net(in_nc=3, out_nc=3, nc=64, nb=16, upscale=4)  # define network
        model_path = os.path.join('model_zoo', 'msrresnet_x4_psnr.pth')  # set model path
    elif model_name == 'imdn':
        from models.network_imdn import IMDN as net
        model = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle')  # define network
        model_path = os.path.join('model_zoo', 'imdn_x4.pth')  # set model path
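
    # A hedged sketch (added comment, not part of the original script): a
    # submission's own network could be plugged in with one more branch in the
    # same pattern, assuming a hypothetical module `models.network_mynet`
    # exposing a class `MyNet` and a checkpoint `model_zoo/mynet_x4.pth`:
    #
    #   elif model_name == 'mynet':
    #       from models.network_mynet import MyNet as net
    #       model = net(in_nc=3, out_nc=3, upscale=4)                # define network
    #       model_path = os.path.join('model_zoo', 'mynet_x4.pth')   # set model path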

    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for k, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # --------------------------------
    # print model summary
    # --------------------------------
    if print_modelsummary:
        from utils.utils_modelsummary import get_model_activation, get_model_flops
        input_dim = (3, 256, 256)  # set the input dimension

        activations, num_conv2d = get_model_activation(model, input_dim)
        logger.info('{:>16s} : {:<.4f} [M]'.format('#Activations', activations/10**6))
        logger.info('{:>16s} : {:<d}'.format('#Conv2d', num_conv2d))

        flops = get_model_flops(model, input_dim, False)
        logger.info('{:>16s} : {:<.4f} [G]'.format('FLOPs', flops/10**9))

        num_parameters = sum(map(lambda x: x.numel(), model.parameters()))
        logger.info('{:>16s} : {:<.4f} [M]'.format('#Params', num_parameters/10**6))
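
        # Back-of-envelope check (added comment, not part of the original
        # script): for a single 3x3 convolution with 64 input and 64 output
        # channels on a 256x256 feature map, the multiply-accumulate count is
        # roughly 64 * 64 * 3 * 3 * 256 * 256 ≈ 2.4e9, i.e. about 2.4 G per
        # such layer. Whether `get_model_flops` reports MACs or counts the
        # multiply and add separately depends on its implementation, so use
        # this figure only as an order-of-magnitude sanity check against the
        # logged `FLOPs` value for the fixed (3, 256, 256) input above.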

    # --------------------------------
    # read image
    # --------------------------------
    L_path = os.path.join(testsets, testset_L)
    E_path = os.path.join(testsets, testset_L+'_'+model_name)
    util.mkdir(E_path)

    # record runtime
    test_results = OrderedDict()
    test_results['runtime'] = []

    logger.info('{:>16s} : {:s}'.format('Input Path', L_path))
    logger.info('{:>16s} : {:s}'.format('Output Path', E_path))
    idx = 0

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for img in util.get_image_paths(L_path):

        # --------------------------------
        # (1) img_L
        # --------------------------------
        idx += 1
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx, img_name+ext))

        img_L = util.imread_uint(img, n_channels=3)
        img_L = util.uint2tensor4(img_L)
        torch.cuda.empty_cache()
        img_L = img_L.to(device)

        start.record()
        img_E = model(img_L)
        # logger.info('{:>16s} : {:<.3f} [M]'.format('Max Memory', torch.cuda.max_memory_allocated(torch.cuda.current_device())/1024**2))  # Memory
        end.record()
        torch.cuda.synchronize()
        test_results['runtime'].append(start.elapsed_time(end))  # milliseconds

        # torch.cuda.synchronize()
        # start = time.time()
        # img_E = model(img_L)
        # torch.cuda.synchronize()
        # end = time.time()
        # test_results['runtime'].append(end-start)  # seconds
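
        # Note on timing (added comment): CUDA kernels run asynchronously, so
        # start.elapsed_time(end) is only valid after torch.cuda.synchronize()
        # has been called; the commented-out time.time() variant above wraps
        # the forward pass in synchronize() calls for the same reason.
        # For `Max Memory`, torch.cuda.max_memory_allocated() reports a peak
        # that accumulates across iterations; resetting it per image (e.g.
        # with torch.cuda.reset_peak_memory_stats(), available in recent
        # PyTorch versions) would be an optional refinement, not something
        # the original script does.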

        # --------------------------------
        # (2) img_E
        # --------------------------------
        img_E = util.tensor2uint(img_E)

        if save_results:
            util.imsave(img_E, os.path.join(E_path, img_name+ext))

    ave_runtime = sum(test_results['runtime']) / len(test_results['runtime']) / 1000.0
    logger.info('------> Average runtime of ({}) is : {:.6f} seconds'.format(L_path, ave_runtime))


if __name__ == '__main__':

    main()
