2
2
import torchvision .models as models
3
3
4
4
models = {
5
- "alexnet" : models .alexnet (pretrained = True ),
6
- "vgg16" : models .vgg16 (pretrained = True ),
7
- "squeezenet" : models .squeezenet1_0 (pretrained = True ),
8
- "densenet" : models .densenet161 (pretrained = True ),
9
- "inception_v3" : models .inception_v3 (pretrained = True ),
5
+ "alexnet" : {
6
+ "model" : models .alexnet (pretrained = True ),
7
+ "path" : "both"
8
+ },
9
+ "vgg16" : {
10
+ "model" : models .vgg16 (pretrained = True ),
11
+ "path" : "both"
12
+ },
13
+ "squeezenet" : {
14
+ "model" : models .squeezenet1_0 (pretrained = True ),
15
+ "path" : "both"
16
+ },
17
+ "densenet" : {
18
+ "model" : models .densenet161 (pretrained = True ),
19
+ "path" : "both"
20
+ },
21
+ "inception_v3" : {
22
+ "model" : models .inception_v3 (pretrained = True ),
23
+ "path" : "both"
24
+ },
10
25
#"googlenet": models.googlenet(pretrained=True),
11
- "shufflenet" : models .shufflenet_v2_x1_0 (pretrained = True ),
12
- "mobilenet_v2" : models .mobilenet_v2 (pretrained = True ),
13
- "resnext50_32x4d" : models .resnext50_32x4d (pretrained = True ),
14
- "wideresnet50_2" : models .wide_resnet50_2 (pretrained = True ),
15
- "mnasnet" : models .mnasnet1_0 (pretrained = True ),
16
- "resnet18" : torch .hub .load ('pytorch/vision:v0.5.0' , 'resnet18' , pretrained = True ),
17
- "resnet50" : torch .hub .load ('pytorch/vision:v0.5.0' , 'resnet50' , pretrained = True )}
26
+ "shufflenet" : {
27
+ "model" : models .shufflenet_v2_x1_0 (pretrained = True ),
28
+ "path" : "both"
29
+ },
30
+ "mobilenet_v2" : {
31
+ "model" : models .mobilenet_v2 (pretrained = True ),
32
+ "path" : "both"
33
+ },
34
+ "resnext50_32x4d" : {
35
+ "model" : models .resnext50_32x4d (pretrained = True ),
36
+ "path" : "both"
37
+ },
38
+ "wideresnet50_2" : {
39
+ "model" : models .wide_resnet50_2 (pretrained = True ),
40
+ "path" : "both"
41
+ },
42
+ "mnasnet" : {
43
+ "model" : models .mnasnet1_0 (pretrained = True ),
44
+ "path" : "both"
45
+ },
46
+ "resnet18" : {
47
+ "model" : torch .hub .load ('pytorch/vision:v0.6.0' , 'resnet18' , pretrained = True ),
48
+ "path" : "both"
49
+ },
50
+ "resnet50" : {
51
+ "model" :torch .hub .load ('pytorch/vision:v0.6.0' , 'resnet50' , pretrained = True ),
52
+ "path" : "both"
53
+ },
54
+ "fcn_resnet101" : {
55
+ "model" : torch .hub .load ('pytorch/vision:v0.6.0' , 'fcn_resnet101' , pretrained = True ),
56
+ "path" : "script"
57
+ },
58
+ "ssd" : {
59
+ "model" : torch .hub .load ('NVIDIA/DeepLearningExamples:torchhub' , 'nvidia_ssd' , model_math = "fp32" ),
60
+ "path" : "trace"
61
+ }
62
+ }
18
63
19
64
for n , m in models .items ():
20
65
print ("Downloading {}" .format (n ))
21
- m = m .eval ().cuda ()
22
- x = torch .ones ((1 , 3 , 224 , 224 )).cuda ()
23
- trace_model = torch .jit .trace (m , x )
24
- torch .jit .save (trace_model , n + '_traced.jit.pt' )
25
- script_model = torch .jit .script (m )
26
- torch .jit .save (script_model , n + '_scripted.jit.pt' )
66
+ m ["model" ] = m ["model" ].eval ().cuda ()
67
+ x = torch .ones ((1 , 3 , 300 , 300 )).cuda ()
68
+ if m ["path" ] == "both" or m ["path" ] == "trace" :
69
+ trace_model = torch .jit .trace (m ["model" ], [x ])
70
+ torch .jit .save (trace_model , n + '_traced.jit.pt' )
71
+ if m ["path" ] == "both" or m ["path" ] == "script" :
72
+ script_model = torch .jit .script (m ["model" ])
73
+ torch .jit .save (script_model , n + '_scripted.jit.pt' )
0 commit comments