
Commit 676bf56

feat(//cpp/ptq/training): Training recipe for VGG16 classifier on CIFAR10 for the PTQ example

Gets about 90-91% accuracy; initial LR 0.01, dropout 0.15

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 8580106 commit 676bf56

File tree

3 files changed: +270 -1 lines

Diff for: .gitignore

+4 -1

```diff
@@ -15,4 +15,7 @@ py/tmp/
 py/.eggs
 .vscode/
 .DS_Store
-._DS_Store
+._DS_Store
+*.pth
+*.pyc
+cpp/ptq/training/vgg16/data/
```

Diff for: cpp/ptq/training/vgg16/main.py

+207
```python
import argparse
import os
import random
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data  # registers torch.utils.data for the DataLoader calls below
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from torch.utils.tensorboard import SummaryWriter

from vgg16 import vgg16

PARSER = argparse.ArgumentParser(description="VGG16 example to use with TRTorch PTQ")
PARSER.add_argument('--epochs', default=300, type=int, help="Number of total epochs to train")
PARSER.add_argument('--batch-size', default=128, type=int, help="Batch size to use when training")
PARSER.add_argument('--lr', default=0.1, type=float, help="Initial learning rate")
PARSER.add_argument('--drop-ratio', default=0., type=float, help="Dropout ratio")
PARSER.add_argument('--momentum', default=0.9, type=float, help="Momentum")
PARSER.add_argument('--weight-decay', default=5e-4, type=float, help="Weight decay")
PARSER.add_argument('--ckpt-dir', default="/tmp/vgg16_ckpts", type=str, help="Path to save checkpoints (saved every 10 epochs)")
PARSER.add_argument('--start-from', default=0, type=int, help="Epoch to resume from (requires a checkpoint in the provided checkpoint directory)")
PARSER.add_argument('--seed', type=int, help="Seed value for rng")
PARSER.add_argument('--tensorboard', type=str, default='/tmp/vgg16_logs', help="Location for tensorboard info")

args = PARSER.parse_args()
for arg in vars(args):
    print(' {} {}'.format(arg, getattr(args, arg)))
# Mutable copy of the parsed arguments; adjust_lr() updates state["lr"] during training
state = {k: v for k, v in args._get_kwargs()}

# Seed all RNGs so runs are reproducible
if args.seed is None:
    args.seed = random.randint(1, 10000)
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
print("RNG seed used: ", args.seed)

now = datetime.now()
timestamp = datetime.timestamp(now)

writer = SummaryWriter(args.tensorboard + '/test_' + str(timestamp))
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def main():
    global state
    global classes
    global writer
    if not os.path.isdir(args.ckpt_dir):
        os.makedirs(args.ckpt_dir)

    # CIFAR10 with standard augmentation (random crop + horizontal flip) plus
    # per-channel normalization for training; normalization only for testing
    training_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                        transform=transforms.Compose([
                                            transforms.RandomCrop(32, padding=4),
                                            transforms.RandomHorizontalFlip(),
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                                 (0.2023, 0.1994, 0.2010)),
                                        ]))
    training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=args.batch_size,
                                                      shuffle=True, num_workers=2)

    testing_dataset = datasets.CIFAR10(root='./data', train=False, download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                                (0.2023, 0.1994, 0.2010)),
                                       ]))
    testing_dataloader = torch.utils.data.DataLoader(testing_dataset, batch_size=args.batch_size,
                                                     shuffle=False, num_workers=2)

    num_classes = len(classes)

    model = vgg16(num_classes=num_classes, init_weights=False)
    model = model.cuda()

    # Log the model graph to TensorBoard using one sample batch
    dataiter = iter(training_dataloader)
    images, _ = next(dataiter)

    writer.add_graph(model, images.cuda())
    writer.close()

    crit = nn.CrossEntropyLoss()
    opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                    weight_decay=args.weight_decay)

    # Resume model, optimizer, and bookkeeping state from a saved checkpoint
    if args.start_from != 0:
        ckpt_file = args.ckpt_dir + '/ckpt_epoch' + str(args.start_from) + '.pth'
        print('Loading from checkpoint {}'.format(ckpt_file))
        assert os.path.isfile(ckpt_file)
        ckpt = torch.load(ckpt_file)
        model.load_state_dict(ckpt["model_state_dict"])
        opt.load_state_dict(ckpt["opt_state_dict"])
        state = ckpt["state"]

    if torch.cuda.device_count() > 1:
        # Note: state dicts saved from a DataParallel-wrapped model carry a
        # "module." prefix on every key
        model = nn.DataParallel(model)

    for epoch in range(args.start_from, args.epochs):
        adjust_lr(opt, epoch)
        writer.add_scalar('Learning Rate', state["lr"], epoch)
        writer.close()
        print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, args.epochs, state['lr']))

        train(model, training_dataloader, crit, opt, epoch)
        test_loss, test_acc = test(model, testing_dataloader, crit, epoch)

        print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

        # Checkpoint every 10 epochs
        if epoch % 10 == 9:
            save_checkpoint({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'acc': test_acc,
                'opt_state_dict': opt.state_dict(),
                'state': state
            }, ckpt_dir=args.ckpt_dir)


def train(model, dataloader, crit, opt, epoch):
    global writer
    model.train()
    running_loss = 0.0
    for batch, (data, labels) in enumerate(dataloader):
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        opt.zero_grad()
        out = model(data)
        loss = crit(out, labels)
        loss.backward()
        opt.step()

        # Log the mean loss over the last 50 batches
        running_loss += loss.item()
        if batch % 50 == 49:
            writer.add_scalar('Training Loss', running_loss / 50, epoch * len(dataloader) + batch)
            writer.close()
            print("Batch: [%5d | %5d] loss: %.3f" % (batch + 1, len(dataloader), running_loss / 50))
            running_loss = 0.0


def test(model, dataloader, crit, epoch):
    global writer
    global classes
    total = 0
    correct = 0
    loss = 0.0
    class_probs = []
    class_preds = []
    model.eval()
    with torch.no_grad():
        for data, labels in dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = model(data)
            loss += crit(out, labels)
            preds = torch.max(out, 1)[1]
            class_probs.append([F.softmax(i, dim=0) for i in out])
            class_preds.append(preds)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

    writer.add_scalar('Testing Loss', loss / total, epoch)
    writer.close()

    writer.add_scalar('Testing Accuracy', correct / total * 100, epoch)
    writer.close()

    # Per-class precision-recall curves for TensorBoard
    test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
    test_preds = torch.cat(class_preds)
    for i in range(len(classes)):
        add_pr_curve_tensorboard(i, test_probs, test_preds, epoch)
    return loss / total, correct / total


def save_checkpoint(state, ckpt_dir='checkpoint'):
    print("Checkpoint {} saved".format(state['epoch']))
    filename = "ckpt_epoch" + str(state['epoch']) + ".pth"
    filepath = os.path.join(ckpt_dir, filename)
    torch.save(state, filepath)


def adjust_lr(optimizer, epoch):
    global state
    # Step decay: halve the LR every 50 epochs, computed from the initial LR
    # (not compounded off the already-decayed value), with a floor of 1e-7
    new_lr = max(args.lr * (0.5 ** (epoch // 50)), 1e-7)
    if new_lr != state["lr"]:
        state["lr"] = new_lr
        print("Updating learning rate: {}".format(state["lr"]))
        for param_group in optimizer.param_groups:
            param_group["lr"] = state["lr"]


def add_pr_curve_tensorboard(class_index, test_probs, test_preds, global_step=0):
    '''
    Takes in a "class_index" from 0 to 9 and plots the corresponding
    precision-recall curve
    '''
    global classes
    tensorboard_preds = test_preds == class_index
    tensorboard_probs = test_probs[:, class_index]

    writer.add_pr_curve(classes[class_index],
                        tensorboard_preds,
                        tensorboard_probs,
                        global_step=global_step)
    writer.close()


if __name__ == "__main__":
    main()
```
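With these flags, the run described in the commit message maps roughly to `python3 main.py --lr 0.01 --drop-ratio 0.15`, and `--start-from N --ckpt-dir /tmp/vgg16_ckpts` resumes from `ckpt_epochN.pth`. As a sanity check on the schedule `adjust_lr()` implements, the minimal sketch below tabulates the step decay; the starting value 0.01 is taken from the commit message (the flag's default is 0.1):

```python
# Sketch of the step decay in adjust_lr(): the learning rate halves
# every 50 epochs, decayed from the initial --lr value.
initial_lr = 0.01  # from the commit message; main.py defaults to 0.1
for epoch in (0, 49, 50, 100, 150, 299):
    print(epoch, initial_lr * 0.5 ** (epoch // 50))
# epochs 0-49 -> 0.01, 50-99 -> 0.005, 100-149 -> 0.0025, ..., 250-299 -> 0.0003125
```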

Diff for: cpp/ptq/training/vgg16/vgg16.py

+59
```python
import torch
import torch.nn as nn


class VGG(nn.Module):
    def __init__(self, layer_spec, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()

        # Build the convolutional feature extractor from the layer spec:
        # integers are conv output channels, 'pool' inserts a 2x2 max-pool
        layers = []
        in_channels = 3
        for l in layer_spec:
            if l == 'pool':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, l, kernel_size=3, padding=1),
                    nn.BatchNorm2d(l),
                    nn.ReLU()
                ]
                in_channels = l

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Note: nn.Dropout() runs at its default rate (0.5); the --drop-ratio
        # flag in main.py is not wired through to the model yet
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def vgg16(num_classes=1000, init_weights=False):
    vgg16_cfg = [64, 64, 'pool',
                 128, 128, 'pool',
                 256, 256, 256, 256, 'pool',
                 512, 512, 512, 512, 'pool',
                 512, 512, 512, 512, 'pool']
    return VGG(vgg16_cfg, num_classes, init_weights)
```
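A quick way to sanity-check the factory (a minimal sketch, assuming `vgg16.py` is on the import path): a CIFAR10-sized batch flows through the five 'pool' stages (32 → 16 → 8 → 4 → 2 → 1), the adaptive average pool broadcasts the 1×1 feature map back up to 7×7, and the classifier emits one logit per class.

```python
import torch
from vgg16 import vgg16

# Hypothetical smoke test: CIFAR10-shaped input through the network
model = vgg16(num_classes=10, init_weights=True)
x = torch.randn(4, 3, 32, 32)  # batch of 4 RGB 32x32 images
out = model(x)                 # features: 512x1x1 -> avgpool: 512x7x7 -> classifier
print(out.shape)               # torch.Size([4, 10])
```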
