
Implemented a Siamese Network Example #1003


Merged · 7 commits · May 10, 2022
Changes from 2 commits
3 changes: 3 additions & 0 deletions .gitignore
@@ -17,3 +17,6 @@ docs/venv

# vi backups
*~

# development
.vscode
7 changes: 7 additions & 0 deletions siamese_network/README.md
@@ -0,0 +1,7 @@
# Basic Siamese Network Example

```bash
pip install -r requirements.txt
python main.py
# CUDA_VISIBLE_DEVICES=2 python main.py  # to run on a specific GPU, e.g. GPU 2
```
Member
Can you also reference the paper you're using as a baseline for your implementation?

Contributor Author
I have referenced FaceNet, which is the closest implementation to this example. This implementation differs from FaceNet in that we use ResNet-18 as the feature extractor. In addition, we aren't using TripletLoss, since the MNIST dataset is simple enough that BCELoss can do the trick.
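
For context, here is a minimal sketch of what the triplet-loss alternative mentioned above could look like with the same ResNet-18 backbone (illustrative only, not part of this PR):

```python
import torch
import torch.nn as nn
import torchvision

# Illustrative only: a triplet-loss setup embeds anchor/positive/negative images
# and pulls the anchor towards the positive and away from the negative.
backbone = torchvision.models.resnet18(pretrained=False)
backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 1-channel MNIST input
backbone.fc = nn.Identity()  # keep the 512-d embedding instead of class logits

criterion = nn.TripletMarginLoss(margin=1.0)

# dummy MNIST-sized batches standing in for (anchor, same-class, different-class) triplets
anchor = torch.randn(8, 1, 28, 28)
positive = torch.randn(8, 1, 28, 28)
negative = torch.randn(8, 1, 28, 28)

loss = criterion(backbone(anchor), backbone(positive), backbone(negative))
loss.backward()
```

By contrast, the example in this PR concatenates the two ResNet-18 embeddings, scores the pair with a small MLP followed by a sigmoid, and trains it with BCELoss on same/different labels.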

214 changes: 214 additions & 0 deletions siamese_network/main.py
@@ -0,0 +1,214 @@
from __future__ import print_function
import argparse, random, copy
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms as T
from torch.optim.lr_scheduler import StepLR


class SiameseNetwork(nn.Module):
def __init__(self):
super(SiameseNetwork, self).__init__()
# get resnet model
self.resnet = torchvision.models.resnet18(pretrained=False)

        # overwrite the first conv layer to be able to read MNIST images
# as resnet18 reads (3,x,x) where 3 is RGB channels
# whereas MNIST has (1,x,x) where 1 is a gray-scale channel
self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
self.fc_in_features = self.resnet.fc.in_features

        # remove the last layer of resnet18 (the final fully connected layer, which comes after avgpool)
self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1]))

        # add linear layers to compare the features of the two images
self.fc = nn.Sequential(
nn.Linear(self.fc_in_features * 2, 256),
nn.ReLU(inplace=True),

nn.Linear(256, 1),
)

self.sigmoid = nn.Sigmoid()

# initialize the weights
self.resnet.apply(self.init_weights)
self.fc.apply(self.init_weights)

def init_weights(self, m):
if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)

def forward_once(self, x):
output = self.resnet(x)
output = output.view(output.size()[0], -1)
return output

def forward(self, input1, input2):
# get two images' features
output1 = self.forward_once(input1)
output2 = self.forward_once(input2)

        # concatenate both images' features
Member
There are a few simple typos around; a quick spell checker would help.

Contributor Author
Thanks for pointing it out! I have fixed it! 😄

output = torch.cat((output1, output2), 1)

        # pass the concatenation to the linear layers
output = self.fc(output)
output = self.sigmoid(output)

return output

class APP_MATCHER(Dataset):
def __init__(self, root, train, download=False):
super(APP_MATCHER, self).__init__()
# get MNIST dataset
self.dataset = datasets.MNIST(root, train=train, download=download)

# get targets (labels) and data (images)
self.targets = copy.deepcopy(self.dataset.targets)
Member
what's the goal of the group sets section? Could you add some comments explaining why the deep copies are needed and what this is doing exactly?

Contributor Author
I have removed unwanted copying of objects and included detailed comments to explain the process.

self.data = copy.deepcopy(self.dataset.data.unsqueeze(1))

self.group_sets()

def group_sets(self):
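        # group the MNIST indices by digit label, so that __getitem__ can quickly
        # sample a second image from the same class (even indices) or from a
        # different class (odd indices)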
np_arr = np.array(self.dataset.targets.clone())
self.grouped_indices = {}
for i in range(0,10):
self.grouped_indices[i] = np.where((np_arr==i))[0]

def __len__(self):
return self.data.shape[0]

def __getitem__(self, index):
Member
add more comments in this section

        # randomly pick the class of the first image, then sample a random image of that class
        selected_class = random.randint(0, 9)
        random_index_1 = random.randint(0, self.grouped_indices[selected_class].shape[0]-1)
        index_1 = self.grouped_indices[selected_class][random_index_1]
        image_1 = self.data[index_1].clone().float()

        # even index: pick the second image from the same class, so the target is 1 (similar pair)
if index % 2 == 0:
random_index_2 = random.randint(0, self.grouped_indices[selected_class].shape[0]-1)
while random_index_2 == random_index_1:
random_index_2 = random.randint(0, self.grouped_indices[selected_class].shape[0]-1)
index_2 = self.grouped_indices[selected_class][random_index_2]
image_2 = self.data[index_2].clone().float()
target = torch.tensor(1, dtype=torch.float)

        # odd index: pick the second image from a different class, so the target is 0 (dissimilar pair)
else:
other_selected_class = random.randint(0, 9)
while other_selected_class == selected_class:
other_selected_class = random.randint(0, 9)
random_index_2 = random.randint(0, self.grouped_indices[other_selected_class].shape[0]-1)
index_2 = self.grouped_indices[other_selected_class][random_index_2]
image_2 = self.data[index_2].clone().float()
target = torch.tensor(0, dtype=torch.float)

return image_1, image_2, target


def train(args, model, device, train_loader, optimizer, epoch):
model.train()
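    # BCELoss on the sigmoid similarity score: target is 1 for same-class pairs, 0 for different-class pairs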
criterion = nn.BCELoss()
for batch_idx, (images_1, images_2, targets) in enumerate(train_loader):
images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(images_1, images_2).squeeze()
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(images_1), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
if args.dry_run:
break


def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
    criterion = nn.BCELoss(reduction='sum')  # sum over the batch so we can report a per-sample average below
with torch.no_grad():
for (images_1, images_2, targets) in test_loader:
images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
outputs = model(images_1, images_2).squeeze()
            test_loss += criterion(outputs, targets).item()  # sum up batch loss
            pred = torch.where(outputs > 0.5, 1, 0)  # threshold the sigmoid similarity score at 0.5
correct += pred.eq(targets.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
Member
Add a comment about how much loss is expected when the default settings are used.

test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))


def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
Member
Update to Siamese network

Contributor Author
Oh, I missed it. Thanks!

parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=14, metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--dry-run', action='store_true', default=False,
help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
if use_cuda:
cuda_kwargs = {'num_workers': 1,
'pin_memory': True,
'shuffle': True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)

train_dataset = APP_MATCHER('../data', train=True, download=True)
test_dataset = APP_MATCHER('../data', train=False)
train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

model = SiameseNetwork().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
scheduler.step()

if args.save_model:
torch.save(model.state_dict(), "siamese_network.pt")


if __name__ == '__main__':
main()
2 changes: 2 additions & 0 deletions siamese_network/requirements.txt
@@ -0,0 +1,2 @@
torch
torchvision