-
Notifications
You must be signed in to change notification settings - Fork 9.6k
/
Copy pathmain.py
214 lines (173 loc) · 8.48 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
from __future__ import print_function
import argparse, random, copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms as T
from torch.optim.lr_scheduler import StepLR
class SiameseNetwork(nn.Module):
    """Siamese network that scores the similarity of two single-channel images.

    Both images are encoded by the same ResNet-18 trunk (its first
    convolution is replaced to accept 1 input channel and its final linear
    classifier is dropped).  The two feature vectors are concatenated and fed
    to a small fully connected head whose single output goes through a
    sigmoid, so ``forward`` returns one value in (0, 1) per pair.
    """

    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # get resnet model; calling the factory without arguments avoids the
        # `pretrained=` keyword, which newer torchvision releases deprecate
        # in favour of `weights=` (the default loads no pretrained weights)
        backbone = torchvision.models.resnet18()

        # over-write the first conv layer to be able to read MNIST images
        # as resnet18 reads (3,x,x) where 3 is RGB channels
        # whereas MNIST has (1,x,x) where 1 is a gray-scale channel
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.fc_in_features = backbone.fc.in_features

        # remove the last child of resnet18 (the linear classifier, which
        # comes after the avgpool layer) so the trunk outputs pooled features
        self.resnet = torch.nn.Sequential(*(list(backbone.children())[:-1]))

        # add linear layers to compare between the features of the two images
        self.fc = nn.Sequential(
            nn.Linear(self.fc_in_features * 2, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )

        self.sigmoid = nn.Sigmoid()

        # initialize the weights (init_weights only touches nn.Linear modules)
        self.resnet.apply(self.init_weights)
        self.fc.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-initialise the weight of an ``nn.Linear`` and set its bias to 0.01.

        Modules of any other type are left untouched, which makes this
        method suitable for ``Module.apply``.
        """
        if isinstance(m, nn.Linear):
            # in-place variant; the un-suffixed `xavier_uniform` is deprecated
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)

    def forward_once(self, x):
        """Encode one batch of images and flatten the result to (batch, features)."""
        output = self.resnet(x)
        output = output.view(output.size()[0], -1)
        return output

    def forward(self, input1, input2):
        """Return the sigmoid similarity score for each (input1, input2) pair.

        Both inputs are passed through the same trunk; the result has one
        column (the head ends in ``nn.Linear(256, 1)``).
        """
        # get two images' features
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)

        # concatenate both images' features
        output = torch.cat((output1, output2), 1)

        # pass the concatenation to the linear layers
        output = self.fc(output)

        # squash the score into (0, 1)
        output = self.sigmoid(output)

        return output
class APP_MATCHER(Dataset):
    """MNIST wrapper that serves random image pairs with a same/different label.

    The requested index only decides the pair type: even indices give two
    images of the same class (target 1.0), odd indices give images of two
    different classes (target 0.0).  The images themselves are drawn with
    Python's ``random`` module.
    """

    def __init__(self, root, train, download=False):
        super(APP_MATCHER, self).__init__()

        # underlying MNIST split
        self.dataset = datasets.MNIST(root, train=train, download=download)

        # private copies of the images (with an extra axis inserted at dim 1)
        # and of the labels
        self.data = copy.deepcopy(self.dataset.data.unsqueeze(1))
        self.targets = copy.deepcopy(self.dataset.targets)

        self.group_sets()

    def group_sets(self):
        """Build ``grouped_indices``: class label (0-9) -> array of sample indices."""
        labels = np.array(self.dataset.targets.clone())
        self.grouped_indices = {
            digit: np.where(labels == digit)[0] for digit in range(0, 10)
        }

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``(image_1, image_2, target)`` for a randomly drawn pair.

        Only the parity of ``index`` is used; see the class docstring.
        """
        # draw the first image: random class, then random member of it
        anchor_class = random.randint(0, 9)
        anchor_pool = self.grouped_indices[anchor_class]
        anchor_pos = random.randint(0, anchor_pool.shape[0] - 1)
        image_1 = self.data[anchor_pool[anchor_pos]].clone().float()

        if index % 2 != 0:
            # different class: redraw until the class differs from the first
            partner_class = random.randint(0, 9)
            while partner_class == anchor_class:
                partner_class = random.randint(0, 9)
            partner_pool = self.grouped_indices[partner_class]
            partner_pos = random.randint(0, partner_pool.shape[0] - 1)
            label = 0
        else:
            # same class: redraw until the position differs from the first
            partner_pool = anchor_pool
            partner_pos = random.randint(0, partner_pool.shape[0] - 1)
            while partner_pos == anchor_pos:
                partner_pos = random.randint(0, partner_pool.shape[0] - 1)
            label = 1

        image_2 = self.data[partner_pool[partner_pos]].clone().float()
        target = torch.tensor(label, dtype=torch.float)

        return image_1, image_2, target
def train(args, model, device, train_loader, optimizer, epoch):
    """Train ``model`` for one epoch with binary cross-entropy loss.

    Args:
        args: namespace providing ``log_interval`` (log every N batches) and
            ``dry_run`` (stop after the first logged batch).
        model: callable taking two image batches and returning one
            probability per pair.
        device: device the batches are moved to.
        train_loader: iterable of ``(images_1, images_2, targets)`` batches;
            its ``dataset`` attribute is used for progress reporting.
        optimizer: optimizer stepping the model parameters.
        epoch: epoch number, used only in the log line.
    """
    model.train()

    criterion = nn.BCELoss()

    for batch_idx, (images_1, images_2, targets) in enumerate(train_loader):
        images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
        optimizer.zero_grad()
        # Flatten to one score per pair. An argument-less squeeze() would
        # collapse a batch of size 1 to a 0-d tensor, and BCELoss rejects a
        # target whose shape differs from the input.
        outputs = model(images_1, images_2).reshape(-1)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(images_1), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break
def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    Args:
        model: callable taking two image batches and returning one
            probability per pair.
        device: device the batches are moved to.
        test_loader: iterable of ``(images_1, images_2, targets)`` batches;
            ``len(test_loader.dataset)`` is used as the sample count.

    A pair is counted as correct when thresholding the predicted probability
    at 0.5 reproduces the target.
    """
    model.eval()
    test_loss = 0
    correct = 0

    # Sum (not average) inside each batch so that dividing the running total
    # by the dataset size below yields a true per-sample mean.
    criterion = nn.BCELoss(reduction='sum')

    with torch.no_grad():
        for (images_1, images_2, targets) in test_loader:
            images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
            # Flatten to one score per pair; an argument-less squeeze() would
            # turn a final batch of size 1 into a 0-d tensor and make
            # BCELoss fail on the shape mismatch with targets.
            outputs = model(images_1, images_2).reshape(-1)
            test_loss += criterion(outputs, targets).item()  # sum up batch loss
            pred = (outputs > 0.5).float()  # threshold the probability at 0.5
            correct += pred.eq(targets.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


# The name matches pytest's default ``test*`` pattern; mark it so pytest does
# not try to collect this helper as a test case when the module is imported.
test.__test__ = False
def main():
    """Parse command-line options, then train and evaluate the Siamese network.

    Builds the paired-MNIST train/test datasets and loaders, trains with
    Adadelta for ``--epochs`` epochs (evaluating and decaying the learning
    rate by ``--gamma`` after each one) and, with ``--save-model``, writes
    the final weights to ``siamese_network.pt``.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Siamese network Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    # The dataset draws its image pairs with Python's `random` module, which
    # torch.manual_seed does not touch; seed it too so --seed is honoured.
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        # NOTE: these extra loader options (including shuffling) are only
        # applied on the CUDA path, for both loaders.
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_dataset = APP_MATCHER('../data', train=True, download=True)
    test_dataset = APP_MATCHER('../data', train=False)
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

    model = SiameseNetwork().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # multiply the learning rate by gamma after every epoch
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "siamese_network.pt")
# Run the training script only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()