-
-
Notifications
You must be signed in to change notification settings - Fork 650
Issue#2878: Adds Siamese Network example #2882
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 5 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
2df93b7
Issue#2878: Adds Siamese Network example
55edb00
Update README.md
cac0692
Updated code formatting
a2341e3
Updated more code formatting errors
32f1c30
Updated some more code formatting errors
ce8e43b
Update dataset, loss function and minor fixes
96c8b22
Merge branch 'master' into deepc004/issue#2878
DeepC004 feea719
Merge branch 'master' into deepc004/issue#2878
DeepC004 d1b66e7
Merge branch 'master' into deepc004/issue#2878
DeepC004 081f1b6
Code refactoring and bottleneck removal
a4e545d
Merge branch 'deepc004/issue#2878' of https://github.com/DeepC004/ign…
1bfb81d
Added accuracy measures
a958372
Merge branch 'master' into deepc004/issue#2878
DeepC004 552d441
added ignite.metrics.Accuracy + minor changes
84c945a
code formatting
faa39f0
minor fixes
46f10f6
Merge branch 'master' into deepc004/issue#2878
vfdev-5 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Siamese Network example on MNIST dataset | ||
|
||
This example is ported over from [pytorch/examples](https://github.com/pytorch/examples) | ||
|
||
Usage: | ||
|
||
``` | ||
pip install -r requirements.txt | ||
python siamese_network.py | ||
``` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
torch | ||
torchvision | ||
pytorch-ignite |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,325 @@ | ||
from __future__ import print_function | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
import argparse | ||
import random | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import torchvision | ||
from torch.optim.lr_scheduler import StepLR | ||
from torch.utils.data import DataLoader | ||
from torchvision import datasets | ||
|
||
from ignite.contrib.handlers import ProgressBar | ||
from ignite.engine import Engine, Events | ||
from ignite.handlers.param_scheduler import LRScheduler | ||
|
||
|
||
class SiameseNetwork(nn.Module): | ||
# update Siamese Network implementation in accordance with the dataset | ||
""" | ||
Siamese network for image similarity estimation. | ||
The network is composed of two identical networks, one for each input. | ||
The output of each network is concatenated and passed to a linear layer. | ||
The output of the linear layer passed through a sigmoid function. | ||
`"FaceNet" <https://arxiv.org/pdf/1503.03832.pdf>`_ is a variant of the Siamese network. | ||
This implementation varies from FaceNet as we use the `ResNet-18` model from | ||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>` | ||
as our feature extractor. | ||
In addition, we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick. | ||
""" | ||
|
||
def __init__(self): | ||
super(SiameseNetwork, self).__init__() | ||
# get resnet model | ||
self.resnet = torchvision.models.resnet18(weights=None) | ||
|
||
# over-write the first conv layer to be able to read MNIST images | ||
# as resnet18 reads (3,x,x) where 3 is RGB channels | ||
# whereas MNIST has (1,x,x) where 1 is a gray-scale channel | ||
self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) | ||
self.fc_in_features = self.resnet.fc.in_features | ||
DeepC004 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# remove the last layer of resnet18 (linear layer which is before avgpool layer) | ||
self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1])) | ||
|
||
# add linear layers to compare between the features of the two images | ||
self.fc = nn.Sequential( | ||
nn.Linear(self.fc_in_features * 2, 256), | ||
nn.ReLU(inplace=True), | ||
nn.Linear(256, 1), | ||
) | ||
|
||
self.sigmoid = nn.Sigmoid() | ||
|
||
# initialize the weights | ||
self.resnet.apply(self.init_weights) | ||
self.fc.apply(self.init_weights) | ||
|
||
def init_weights(self, m): | ||
if isinstance(m, nn.Linear): | ||
torch.nn.init.xavier_uniform_(m.weight) | ||
m.bias.data.fill_(0.01) | ||
|
||
def forward_once(self, x): | ||
output = self.resnet(x) | ||
output = output.view(output.size()[0], -1) | ||
return output | ||
|
||
def forward(self, input1, input2): | ||
# get two images' features | ||
output1 = self.forward_once(input1) | ||
output2 = self.forward_once(input2) | ||
|
||
# concatenate both images' feature | ||
output = torch.cat((output1, output2), 1) | ||
|
||
# pass the concatenation to the linear layers | ||
output = self.fc(output) | ||
|
||
# pass the out of the linear layers to sigmoid layer | ||
output = self.sigmoid(output) | ||
|
||
return output | ||
|
||
|
||
class APP_MATCHER: | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# following class implements data downloading and handles preprocessing | ||
def __init__(self, root, train, download=False): | ||
super(APP_MATCHER, self).__init__() | ||
|
||
# get MNIST dataset | ||
self.dataset = datasets.MNIST(root, train=train, download=download) | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# as `self.dataset.data`'s shape is (Nx28x28), where N is the number of | ||
# examples in MNIST dataset, a single example has the dimensions of | ||
# (28x28) for (WxH), where W and H are the width and the height of the image. | ||
# However, every example should have (CxWxH) dimensions where C is the number | ||
# of channels to be passed to the network. As MNIST contains gray-scale images, | ||
# we add an additional dimension to corresponds to the number of channels. | ||
self.data = self.dataset.data.unsqueeze(1).clone() | ||
|
||
self.group_examples() | ||
|
||
def group_examples(self): | ||
""" | ||
To ease the accessibility of data based on the class, we will use `group_examples` to group | ||
examples based on class. | ||
|
||
Every key in `grouped_examples` corresponds to a class in MNIST dataset. For every key in | ||
`grouped_examples`, every value will conform to all of the indices for the MNIST | ||
dataset examples that correspond to that key. | ||
""" | ||
|
||
# get the targets from MNIST dataset | ||
np_arr = np.array(self.dataset.targets.clone()) | ||
|
||
# group examples based on class | ||
self.grouped_examples = {} | ||
for i in range(0, 10): | ||
self.grouped_examples[i] = np.where((np_arr == i))[0] | ||
|
||
def __len__(self): | ||
return self.data.shape[0] | ||
|
||
def __getitem__(self, index): | ||
""" | ||
For every example, we will select two images. There are two cases, | ||
positive and negative examples. For positive examples, we will have two | ||
images from the same class. For negative examples, we will have two images | ||
from different classes. | ||
|
||
Given an index, if the index is even, we will pick the second image from the same class, | ||
but it won't be the same image we chose for the first class. This is used to ensure the positive | ||
example isn't trivial as the network would easily distinguish the similarity between same images. However, | ||
if the network were given two different images from the same class, the network will need to learn | ||
the similarity between two different images representing the same class. If the index is odd, we will | ||
pick the second image from a different class than the first image. | ||
""" | ||
|
||
# pick some random class for the first image | ||
selected_class = random.randint(0, 9) | ||
|
||
# pick a random index for the first image in the grouped indices based of the label | ||
# of the class | ||
random_index_1 = random.randint(0, self.grouped_examples[selected_class].shape[0] - 1) | ||
|
||
# pick the index to get the first image | ||
index_1 = self.grouped_examples[selected_class][random_index_1] | ||
|
||
# get the first image | ||
image_1 = self.data[index_1].clone().float() | ||
|
||
# same class | ||
if index % 2 == 0: | ||
# pick a random index for the second image | ||
random_index_2 = random.randint(0, self.grouped_examples[selected_class].shape[0] - 1) | ||
|
||
# ensure that the index of the second image isn't the same as the first image | ||
while random_index_2 == random_index_1: | ||
random_index_2 = random.randint(0, self.grouped_examples[selected_class].shape[0] - 1) | ||
|
||
# pick the index to get the second image | ||
index_2 = self.grouped_examples[selected_class][random_index_2] | ||
|
||
# get the second image | ||
image_2 = self.data[index_2].clone().float() | ||
|
||
# set the label for this example to be positive (1) | ||
target = torch.tensor(1, dtype=torch.float) | ||
|
||
# different class | ||
else: | ||
# pick a random class | ||
other_selected_class = random.randint(0, 9) | ||
|
||
# ensure that the class of the second image isn't the same as the first image | ||
while other_selected_class == selected_class: | ||
other_selected_class = random.randint(0, 9) | ||
|
||
# pick a random index for the second image in the grouped indices based of the label | ||
# of the class | ||
random_index_2 = random.randint(0, self.grouped_examples[other_selected_class].shape[0] - 1) | ||
|
||
# pick the index to get the second image | ||
index_2 = self.grouped_examples[other_selected_class][random_index_2] | ||
|
||
# get the second image | ||
image_2 = self.data[index_2].clone().float() | ||
|
||
# set the label for this example to be negative (0) | ||
target = torch.tensor(0, dtype=torch.float) | ||
|
||
return image_1, image_2, target | ||
|
||
|
||
def train(model, device, optimizer, train_loader, lr_scheduler, log_interval, max_epochs): | ||
|
||
# we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick. | ||
criterion = nn.BCELoss() | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# define model training step | ||
def train_step(engine, batch): | ||
model.train() | ||
image_1, image_2, target = batch | ||
image_1, image_2, target = image_1.to(device), image_2.to(device), target.to(device) | ||
optimizer.zero_grad() | ||
outputs = model( | ||
image_1, | ||
image_2, | ||
).squeeze() | ||
loss = criterion(outputs, target) | ||
loss.backward() | ||
optimizer.step() | ||
return loss | ||
|
||
# create a trainer engine and attach train_step | ||
trainer = Engine(train_step) | ||
|
||
# attach progress bar to trainer | ||
pbar = ProgressBar() | ||
pbar.attach(trainer) | ||
|
||
# attach various handlers to trainer engine | ||
@trainer.on(Events.ITERATION_COMPLETED(every=log_interval)) | ||
def log_training_results(engine): | ||
print(f"Train Epoch: {engine.state.epoch}, Train Loss: {engine.state.output: .5f}") | ||
DeepC004 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler) | ||
|
||
# run trainer engine | ||
trainer.run(train_loader, max_epochs=max_epochs) | ||
|
||
|
||
def test(model, device, test_loader, lr_scheduler, log_interval): | ||
|
||
# we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick. | ||
criterion = nn.BCELoss() | ||
average_test_loss = 0 | ||
|
||
# define model testing step | ||
def test_step(engine, batch): | ||
model.eval() | ||
image_1, image_2, target = batch | ||
image_1, image_2, target = image_1.to(device), image_2.to(device), target.to(device) | ||
outputs = model(image_1, image_2).squeeze() | ||
test_loss = criterion(outputs, target) | ||
return test_loss | ||
|
||
# create evaluator engine and attach test step | ||
evaluator = Engine(test_step) | ||
|
||
# attach progress bar to evaluator | ||
pbar = ProgressBar() | ||
pbar.attach(evaluator) | ||
|
||
# attach various handlers to evaluator engine | ||
@evaluator.on(Events.ITERATION_COMPLETED(every=log_interval)) | ||
def log_testing_results(engine): | ||
nonlocal average_test_loss | ||
average_test_loss += engine.state.output | ||
print(f"Test Epoch: {engine.state.epoch} Test Loss: {engine.state.output: .5f}") | ||
|
||
evaluator.add_event_handler(Events.ITERATION_STARTED, lr_scheduler) | ||
DeepC004 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# run evaluator engine | ||
evaluator.run(test_loader) | ||
|
||
# print average loss over test dataset | ||
print(f"Average Test Loss: {average_test_loss/len(test_loader.dataset): .7f}") | ||
|
||
|
||
def main(): | ||
# adds training defaults and support for terminal arguments | ||
parser = argparse.ArgumentParser(description="PyTorch Siamese network Example") | ||
parser.add_argument( | ||
"--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" | ||
) | ||
parser.add_argument( | ||
"--test-batch-size", type=int, default=200, metavar="N", help="input batch size for testing (default: 1000)" | ||
) | ||
parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 14)") | ||
parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)") | ||
parser.add_argument( | ||
"--gamma", type=float, default=0.95, metavar="M", help="Learning rate step gamma (default: 0.7)" | ||
) | ||
parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") | ||
parser.add_argument("--no-mps", action="store_true", default=False, help="disables macOS GPU training") | ||
parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass") | ||
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
parser.add_argument( | ||
"--log-interval", | ||
type=int, | ||
default=10, | ||
metavar="N", | ||
help="how many batches to wait before logging training status", | ||
) | ||
parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model") | ||
args = parser.parse_args() | ||
|
||
# set device | ||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | ||
|
||
# data loading | ||
train_dataset = APP_MATCHER("../data", train=True, download=True) | ||
test_dataset = APP_MATCHER("../data", train=False) | ||
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size) | ||
DeepC004 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size) | ||
|
||
# set model parameters | ||
model = SiameseNetwork().to(device) | ||
optimizer = optim.Adadelta(model.parameters(), lr=args.lr) | ||
scheduler = StepLR(optimizer, step_size=15, gamma=args.gamma) | ||
lr_scheduler = LRScheduler(scheduler) | ||
|
||
# call train function | ||
train(model, device, optimizer, train_loader, lr_scheduler, log_interval=args.log_interval, max_epochs=args.epochs) | ||
|
||
# call test function | ||
test(model, device, test_loader, lr_scheduler, log_interval=args.log_interval) | ||
DeepC004 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.