import argparse
import os
import random
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from torch.utils.tensorboard import SummaryWriter

from vgg16 import vgg16

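# Command-line options controlling the training run.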
PARSER = argparse.ArgumentParser(description="VGG16 example to use with TRTorch PTQ")
PARSER.add_argument('--epochs', default=300, type=int, help="Number of total epochs to train")
PARSER.add_argument('--batch-size', default=128, type=int, help="Batch size to use when training")
PARSER.add_argument('--lr', default=0.1, type=float, help="Initial learning rate")
PARSER.add_argument('--drop-ratio', default=0., type=float, help="Dropout ratio")
PARSER.add_argument('--momentum', default=0.9, type=float, help="Momentum")
PARSER.add_argument('--weight-decay', default=5e-4, type=float, help="Weight decay")
PARSER.add_argument('--ckpt-dir', default="/tmp/vgg16_ckpts", type=str, help="Path to save checkpoints (saved every 10 epochs)")
PARSER.add_argument('--start-from', default=0, type=int, help="Epoch to resume from (requires a checkpoint in the provided checkpoint directory)")
PARSER.add_argument('--seed', type=int, help='Seed value for rng')
PARSER.add_argument('--tensorboard', type=str, default='/tmp/vgg16_logs', help='Location for tensorboard info')

args = PARSER.parse_args()
for arg in vars(args):
    print(' {} {}'.format(arg, getattr(args, arg)))
state = {k: v for k, v in args._get_kwargs()}

# Seed every RNG so runs are reproducible; pick a random seed if none was given.
if args.seed is None:
    args.seed = random.randint(1, 10000)
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
print("RNG seed used: ", args.seed)

now = datetime.now()
timestamp = datetime.timestamp(now)

writer = SummaryWriter(args.tensorboard + '/test_' + str(timestamp))
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

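
# Entry point: builds the CIFAR-10 pipelines, the VGG16 model, and the
# optimizer, then runs the train/test loop, checkpointing every 10 epochs.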
def main():
    global state
    global classes
    global writer
    if not os.path.isdir(args.ckpt_dir):
        os.makedirs(args.ckpt_dir)

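    # CIFAR-10 training set with standard augmentation (random crop + flip)
    # and per-channel normalization; the test set is only normalized.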
    training_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                        transform=transforms.Compose([
                                            transforms.RandomCrop(32, padding=4),
                                            transforms.RandomHorizontalFlip(),
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                                 (0.2023, 0.1994, 0.2010)),
                                        ]))
    training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=args.batch_size,
                                                      shuffle=True, num_workers=2)

    testing_dataset = datasets.CIFAR10(root='./data', train=False, download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                                (0.2023, 0.1994, 0.2010)),
                                       ]))
    testing_dataloader = torch.utils.data.DataLoader(testing_dataset, batch_size=args.batch_size,
                                                     shuffle=False, num_workers=2)

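    # Build the model and log its graph to TensorBoard using one sample batch.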
    num_classes = len(classes)

    model = vgg16(num_classes=num_classes, init_weights=False)
    model = model.cuda()

    batch_iter = iter(training_dataloader)
    images, _ = next(batch_iter)

    writer.add_graph(model, images.cuda())
    writer.close()

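    # Loss, optimizer, and optional resume from a saved checkpoint.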
    crit = nn.CrossEntropyLoss()
    opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.start_from != 0:
        ckpt_file = args.ckpt_dir + '/ckpt_epoch' + str(args.start_from) + '.pth'
        print('Loading from checkpoint {}'.format(ckpt_file))
        assert os.path.isfile(ckpt_file)
        ckpt = torch.load(ckpt_file)
        model.load_state_dict(ckpt["model_state_dict"])
        opt.load_state_dict(ckpt["opt_state_dict"])
        state = ckpt["state"]

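    # Use all available GPUs for training.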
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    for epoch in range(args.start_from, args.epochs):
        adjust_lr(opt, epoch)
        writer.add_scalar('Learning Rate', state["lr"], epoch)
        writer.close()
        print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, args.epochs, state['lr']))

        train(model, training_dataloader, crit, opt, epoch)
        test_loss, test_acc = test(model, testing_dataloader, crit, epoch)

        print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

        if epoch % 10 == 9:
            save_checkpoint({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'acc': test_acc,
                'opt_state_dict': opt.state_dict(),
                'state': state
            }, ckpt_dir=args.ckpt_dir)

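
# Run one epoch of training, logging the average loss every 50 batches.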
def train(model, dataloader, crit, opt, epoch):
    global writer
    model.train()
    running_loss = 0.0
    for batch, (data, labels) in enumerate(dataloader):
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        opt.zero_grad()
        out = model(data)
        loss = crit(out, labels)
        loss.backward()
        opt.step()

        running_loss += loss.item()
        if batch % 50 == 49:
            writer.add_scalar('Training Loss', running_loss / 50, epoch * len(dataloader) + batch)
            writer.close()
            print("Batch: [%5d | %5d] loss: %.3f" % (batch + 1, len(dataloader), running_loss / 50))
            running_loss = 0.0

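
# Evaluate on the test set: returns (average loss, accuracy) and logs loss,
# accuracy, and per-class precision-recall curves to TensorBoard.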
def test(model, dataloader, crit, epoch):
    global writer
    global classes
    total = 0
    correct = 0
    loss = 0.0
    class_probs = []
    class_preds = []
    model.eval()
    with torch.no_grad():
        for data, labels in dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = model(data)
            loss += crit(out, labels).item()
            preds = torch.max(out, 1)[1]
            class_probs.append([F.softmax(i, dim=0) for i in out])
            class_preds.append(preds)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

    writer.add_scalar('Testing Loss', loss / total, epoch)
    writer.close()

    writer.add_scalar('Testing Accuracy', correct / total * 100, epoch)
    writer.close()

    test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
    test_preds = torch.cat(class_preds)
    for i in range(len(classes)):
        add_pr_curve_tensorboard(i, test_probs, test_preds, epoch)
    return loss / total, correct / total

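
# Persist a training checkpoint (model, optimizer, and run state) to ckpt_dir.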
def save_checkpoint(state, ckpt_dir='checkpoint'):
    print("Checkpoint {} saved".format(state['epoch']))
    filename = "ckpt_epoch" + str(state['epoch']) + ".pth"
    filepath = os.path.join(ckpt_dir, filename)
    torch.save(state, filepath)

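
# Step-decay schedule: halve the initial learning rate every 50 epochs,
# clamped at a floor of 1e-7, and push the new value to the optimizer.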
def adjust_lr(optimizer, epoch):
    global state
    new_lr = max(args.lr * (0.5 ** (epoch // 50)), 1e-7)
    if new_lr != state["lr"]:
        state["lr"] = new_lr
        print("Updating learning rate: {}".format(state["lr"]))
        for param_group in optimizer.param_groups:
            param_group["lr"] = state["lr"]

def add_pr_curve_tensorboard(class_index, test_probs, test_preds, global_step=0):
    '''
    Takes in a "class_index" from 0 to 9 and plots the corresponding
    precision-recall curve
    '''
    global classes
    tensorboard_preds = test_preds == class_index
    tensorboard_probs = test_probs[:, class_index]

    writer.add_pr_curve(classes[class_index],
                        tensorboard_preds,
                        tensorboard_probs,
                        global_step=global_step)
    writer.close()


if __name__ == "__main__":
    main()
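
# Example invocation (script filename assumed):
#   python main.py --epochs 300 --batch-size 128 --ckpt-dir /tmp/vgg16_ckpts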