model_factory.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from torchvision.models.densenet import densenet121, densenet169, densenet161, densenet201
from torchvision.models.inception import inception_v3
import torchvision.transforms as transforms
from PIL import Image


def create_model(model_name, num_classes=1000, pretrained=False, **kwargs):
    """Create a model by name.

    The DPN variants take a 'test_time_pool' kwarg (defaults to True);
    any remaining kwargs are passed through to the torchvision constructors.
    """
    if 'test_time_pool' in kwargs:
        test_time_pool = kwargs.pop('test_time_pool')
    else:
        test_time_pool = True
    if model_name == 'dpn68':
        model = dpn68(
            pretrained=pretrained, test_time_pool=test_time_pool, num_classes=num_classes)
    elif model_name == 'dpn68b':
        model = dpn68b(
            pretrained=pretrained, test_time_pool=test_time_pool, num_classes=num_classes)
    elif model_name == 'dpn92':
        model = dpn92(
            pretrained=pretrained, test_time_pool=test_time_pool, num_classes=num_classes)
    elif model_name == 'dpn98':
        model = dpn98(
            pretrained=pretrained, test_time_pool=test_time_pool, num_classes=num_classes)
    elif model_name == 'dpn131':
        model = dpn131(
            pretrained=pretrained, test_time_pool=test_time_pool, num_classes=num_classes)
    elif model_name == 'dpn107':
        model = dpn107(
            pretrained=pretrained, test_time_pool=test_time_pool, num_classes=num_classes)
    elif model_name == 'resnet18':
        model = resnet18(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'resnet34':
        model = resnet34(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'resnet50':
        model = resnet50(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'resnet101':
        model = resnet101(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'resnet152':
        model = resnet152(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'densenet121':
        model = densenet121(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'densenet161':
        model = densenet161(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'densenet169':
        model = densenet169(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'densenet201':
        model = densenet201(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'inception_v3':
        model = inception_v3(
            pretrained=pretrained, num_classes=num_classes, transform_input=False, **kwargs)
    else:
        raise ValueError("Unknown model architecture (%s)" % model_name)
    return model
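
# Usage sketch (not part of the original file): the DPN variants honor the
# 'test_time_pool' kwarg popped above, while extra kwargs pass straight
# through to the torchvision constructors, e.g.
#
#   model = create_model('dpn68', num_classes=1000, pretrained=False,
#                        test_time_pool=False)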


class LeNormalize(object):
    """Scale a tensor from the [0, 1] ToTensor range to [-1, 1], Google Inception style."""

    def __call__(self, tensor):
        for t in tensor:
            t.sub_(0.5).mul_(2.0)
        return tensor


DEFAULT_CROP_PCT = 0.875


def get_transforms_eval(model_name, img_size=224, crop_pct=None):
    if 'dpn' in model_name:
        if crop_pct is None:
            # Use the default 87.5% crop for the model's native img_size,
            # but a 100% crop for larger-than-native sizes as that improves
            # test-time results across all models.
            if img_size == 224:
                scale_size = int(math.floor(img_size / DEFAULT_CROP_PCT))
            else:
                scale_size = img_size
        else:
            scale_size = int(math.floor(img_size / crop_pct))
        # DPN weights expect the mean/std of the original implementation
        # (values given in the 0-255 range, rescaled here to 0-1).
        normalize = transforms.Normalize(
            mean=[124 / 255, 117 / 255, 104 / 255],
            std=[1 / (.0167 * 255)] * 3)
    elif 'inception' in model_name:
        scale_size = int(math.floor(img_size / (crop_pct or DEFAULT_CROP_PCT)))
        normalize = LeNormalize()
    else:
        scale_size = int(math.floor(img_size / (crop_pct or DEFAULT_CROP_PCT)))
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize(scale_size, Image.BICUBIC),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize])
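

# Minimal end-to-end sketch (not part of the original file): build a model and
# its matching eval transform, then classify one image. 'example.jpg' is a
# placeholder path, and pretrained=True assumes the weights can be downloaded.
if __name__ == '__main__':
    import torch

    name = 'dpn92'
    model = create_model(name, num_classes=1000, pretrained=True)
    model.eval()
    tf = get_transforms_eval(name, img_size=224)

    img = Image.open('example.jpg').convert('RGB')
    batch = tf(img).unsqueeze(0)  # add a batch dimension -> (1, 3, 224, 224)
    with torch.no_grad():
        logits = model(batch)
    top5_prob, top5_idx = logits.softmax(dim=1).topk(5)
    print(top5_idx.tolist(), top5_prob.tolist())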