Implemented a Siamese Network Example #1003
Changes from 2 commits
@@ -17,3 +17,6 @@ docs/venv

# vi backups
*~

# development
.vscode
@@ -0,0 +1,7 @@
# Basic Siamese Network Example

```bash
pip install -r requirements.txt
python main.py
# CUDA_VISIBLE_DEVICES=2 python main.py  # to specify GPU id, e.g. 2
```
@@ -0,0 +1,214 @@
from __future__ import print_function
import argparse, random, copy
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms as T
from torch.optim.lr_scheduler import StepLR

class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # get resnet model
        self.resnet = torchvision.models.resnet18(pretrained=False)

        # over-write the first conv layer to be able to read MNIST images,
        # as resnet18 reads (3,x,x) where 3 is the number of RGB channels,
        # whereas MNIST has (1,x,x) where 1 is a gray-scale channel
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.fc_in_features = self.resnet.fc.in_features

        # remove the last layer of resnet18 (the linear layer that follows the avgpool layer)
        self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1]))

        # add linear layers to compare between the features of the two images
        self.fc = nn.Sequential(
            nn.Linear(self.fc_in_features * 2, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )

        self.sigmoid = nn.Sigmoid()

        # initialize the weights
        self.resnet.apply(self.init_weights)
        self.fc.apply(self.init_weights)

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def forward_once(self, x):
        output = self.resnet(x)
        output = output.view(output.size()[0], -1)
        return output

    def forward(self, input1, input2):
        # get two images' features
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)

        # concatenate both images' features
**Review comment:** there's a few simple typos around, a quick spell checker would help

**Author:** Thanks for pointing it out! I have fixed it! 😄
        output = torch.cat((output1, output2), 1)

        # pass the concatenation to the linear layers
        output = self.fc(output)
        output = self.sigmoid(output)

        return output

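A quick shape check of the model above (an illustrative sketch, not part of main.py; it assumes the SiameseNetwork class just defined is in scope):

```python
import torch

# two batches of 4 single-channel 28x28 (MNIST-sized) images
model = SiameseNetwork()
img1 = torch.randn(4, 1, 28, 28)
img2 = torch.randn(4, 1, 28, 28)

# one similarity score per pair, in (0, 1) after the sigmoid
scores = model(img1, img2)
print(scores.shape)  # torch.Size([4, 1])
```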
class APP_MATCHER(Dataset):
    def __init__(self, root, train, download=False):
        super(APP_MATCHER, self).__init__()
        # get MNIST dataset
        self.dataset = datasets.MNIST(root, train=train, download=download)

        # get targets (labels) and data (images)
        self.targets = copy.deepcopy(self.dataset.targets)
**Review comment:** What's the goal of the group sets section? Could you add some comments explaining why the deep copies are needed and what this is doing exactly?

**Author:** I have removed unwanted copying of objects and included detailed comments to explain the process.
        self.data = copy.deepcopy(self.dataset.data.unsqueeze(1))

        self.group_sets()

    def group_sets(self):
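        # Build a lookup from each digit class (0-9) to the indices of all of its
        # examples in the dataset. __getitem__ uses this lookup to sample a second
        # image from the same class (positive pair) or a different class (negative pair).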
        np_arr = np.array(self.dataset.targets.clone())
        self.grouped_indices = {}
        for i in range(0, 10):
            self.grouped_indices[i] = np.where((np_arr == i))[0]

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
**Review comment:** add more comments in this section
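        # Pick a random digit class and a random first image from that class.
        # Even indices produce a positive pair (second image from the same class,
        # target 1); odd indices produce a negative pair (second image from a
        # different class, target 0).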
        selected_class = random.randint(0, 9)
        random_index_1 = random.randint(0, self.grouped_indices[selected_class].shape[0] - 1)
        index_1 = self.grouped_indices[selected_class][random_index_1]
        image_1 = self.data[index_1].clone().float()

        # same class
        if index % 2 == 0:
            random_index_2 = random.randint(0, self.grouped_indices[selected_class].shape[0] - 1)
            while random_index_2 == random_index_1:
                random_index_2 = random.randint(0, self.grouped_indices[selected_class].shape[0] - 1)
            index_2 = self.grouped_indices[selected_class][random_index_2]
            image_2 = self.data[index_2].clone().float()
            target = torch.tensor(1, dtype=torch.float)

        # different class
        else:
            other_selected_class = random.randint(0, 9)
            while other_selected_class == selected_class:
                other_selected_class = random.randint(0, 9)
            random_index_2 = random.randint(0, self.grouped_indices[other_selected_class].shape[0] - 1)
            index_2 = self.grouped_indices[other_selected_class][random_index_2]
            image_2 = self.data[index_2].clone().float()
            target = torch.tensor(0, dtype=torch.float)

        return image_1, image_2, target

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    criterion = nn.BCELoss()
    for batch_idx, (images_1, images_2, targets) in enumerate(train_loader):
        images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(images_1, images_2).squeeze()
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(images_1), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    criterion = nn.BCELoss()
    with torch.no_grad():
        for (images_1, images_2, targets) in test_loader:
            images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
            outputs = model(images_1, images_2).squeeze()
            test_loss += criterion(outputs, targets).sum().item()  # sum up batch loss
            pred = torch.where(outputs > 0.5, 1, 0)  # predict "same class" when the similarity score exceeds 0.5
            correct += pred.eq(targets.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

**Review comment:** Add a comment around how much loss is expected if default settings are used.

def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Siamese network Example')
**Review comment:** Update to Siamese network

**Author:** Oh, I missed it. Thanks!
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_dataset = APP_MATCHER('../data', train=True, download=True)
    test_dataset = APP_MATCHER('../data', train=False)
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

    model = SiameseNetwork().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "siamese_network.pt")


if __name__ == '__main__':
    main()
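As a rough sketch of how the saved checkpoint could be used afterwards (not part of this PR; it assumes main.py was run with --save-model and that SiameseNetwork is importable or defined as above):

```python
import torch
from torchvision import datasets

# rebuild the model and load the weights written by `python main.py --save-model`
model = SiameseNetwork()
model.load_state_dict(torch.load("siamese_network.pt", map_location="cpu"))
model.eval()

# compare two MNIST test images, mirroring APP_MATCHER's (1, 28, 28) float
# tensors with an added batch dimension
mnist = datasets.MNIST('../data', train=False, download=True)
img_a = mnist.data[0].unsqueeze(0).unsqueeze(0).float()
img_b = mnist.data[1].unsqueeze(0).unsqueeze(0).float()

with torch.no_grad():
    score = model(img_a, img_b).item()  # probability-like score in (0, 1)
print(f"similarity score: {score:.3f}")
```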
@@ -0,0 +1,2 @@
torch
torchvision
**Review comment:** Can you also reference the paper you're using as a baseline for your implementation?

**Author:** I have referenced FaceNet, the closest implementation to this example's implementation. This implementation varies from FaceNet in that we use the ResNet-18 model as our feature extractor. In addition, we aren't using `TripletLoss`, as the MNIST dataset is simple, so `BCELoss` can do the trick.
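For contrast with the `BCELoss` approach above, a FaceNet-style triplet-loss variant could look roughly like the sketch below. This is illustrative only, not code from this PR: it assumes a sampler that yields (anchor, positive, negative) image triples instead of APP_MATCHER's pairs, and it reuses the SiameseNetwork trunk purely as an embedding network.

```python
import torch
import torch.nn as nn

# FaceNet-style training compares embeddings directly with a triplet loss
# instead of concatenating features and scoring pairs with BCELoss.
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

model = SiameseNetwork()  # used here only as a shared feature extractor

def triplet_step(anchor, positive, negative):
    # embed each image with the shared ResNet-18 trunk
    emb_a = model.forward_once(anchor)
    emb_p = model.forward_once(positive)
    emb_n = model.forward_once(negative)
    # pull anchor/positive together and push anchor/negative apart
    return triplet_loss(emb_a, emb_p, emb_n)

# random MNIST-sized batches standing in for a real triplet sampler
loss = triplet_step(torch.randn(8, 1, 28, 28),
                    torch.randn(8, 1, 28, 28),
                    torch.randn(8, 1, 28, 28))
loss.backward()
```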