
Commit 9a44540

QasimKhan5x and Svetlana Karslioglu authored
Fix inconsistencies in fgsm_tutorial (#2419)

* Change model architecture: the architecture did not match the one in the Basic MNIST Example, so it has been changed to be exactly the same.
* Add normalization transform in the dataloader: the model is trained on normalized data, so it is unfair to evaluate it on unnormalized data in this example.
* Add denormalization code: the MNIST model is trained with normalized data, but no normalization was applied in this tutorial. A denorm function is therefore added and called to denormalize the data before performing FGSM; the perturbed data is normalized again before being fed to the model.
* Update the command to download the FGSM MNIST weights.

Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 9fa95f0 commit 9a44540
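To see why the missing normalization mattered, here is a small illustration (not part of the commit): after the standard MNIST normalization, pixel values leave the [0, 1] range, so perturbing and clamping the normalized tensor directly, as the old tutorial did, mixes two different scales.

    import torch
    from torchvision import transforms

    x = torch.rand(1, 28, 28)                               # raw pixels in [0, 1]
    x_norm = transforms.Normalize((0.1307,), (0.3081,))(x)  # what the model actually sees
    print(x.min().item(), x.max().item())                   # ~0.0 .. ~1.0
    print(x_norm.min().item(), x_norm.max().item())         # ~-0.42 .. ~2.82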

File tree

2 files changed: +51 -17 lines changed


Diff for: Makefile (+1 -1)

@@ -82,7 +82,7 @@ download:
 	tar $(TAROPTS) -xzf $(DATADIR)/UrbanSound8K.tar.gz -C ./beginner_source/data/
 
 	# Download model for beginner_source/fgsm_tutorial.py
-	wget -nv -N https://s3.amazonaws.com/pytorch-tutorial-assets/lenet_mnist_model.pth -P $(DATADIR)
+	wget -nv -N 'https://docs.google.com/uc?export=download&id=1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl' -O $(DATADIR)/lenet_mnist_model.pth
 	cp $(DATADIR)/lenet_mnist_model.pth ./beginner_source/data/lenet_mnist_model.pth
 
 	# Download model for advanced_source/dynamic_quantization_tutorial.py
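The same download can be done from Python if wget is unavailable; a minimal sketch, assuming the Drive link serves the raw checkpoint file directly, as the wget command above expects:

    import urllib.request

    url = ("https://docs.google.com/uc?export=download"
           "&id=1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl")
    urllib.request.urlretrieve(url, "lenet_mnist_model.pth")  # save next to the tutorial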

Diff for: beginner_source/fgsm_tutorial.py (+50 -16)
@@ -123,7 +123,7 @@
 # - ``pretrained_model`` - path to the pretrained MNIST model which was
 #   trained with
 #   `pytorch/examples/mnist <https://github.com/pytorch/examples/tree/master/mnist>`__.
-#   For simplicity, download the pretrained model `here <https://drive.google.com/drive/folders/1fn83DF14tWmit0RTKWRhPq5uVXt73e0h?usp=sharing>`__.
+#   For simplicity, download the pretrained model `here <https://drive.google.com/file/d/1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl/view?usp=drive_link>`__.
 #
 # - ``use_cuda`` - boolean flag to use CUDA if desired and available.
 #   Note, a GPU with CUDA is not critical for this tutorial as a CPU will
@@ -154,26 +154,34 @@
 class Net(nn.Module):
     def __init__(self):
         super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
-        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
-        self.conv2_drop = nn.Dropout2d()
-        self.fc1 = nn.Linear(320, 50)
-        self.fc2 = nn.Linear(50, 10)
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
 
     def forward(self, x):
-        x = F.relu(F.max_pool2d(self.conv1(x), 2))
-        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
-        x = x.view(-1, 320)
-        x = F.relu(self.fc1(x))
-        x = F.dropout(x, training=self.training)
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
         x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
+        output = F.log_softmax(x, dim=1)
+        return output
 
 # MNIST Test dataset and dataloader declaration
 test_loader = torch.utils.data.DataLoader(
     datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
             transforms.ToTensor(),
-            ])),
+            transforms.Normalize((0.1307,), (0.3081,)),
+            ])),
     batch_size=1, shuffle=True)
 
 # Define what device we are using
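A quick shape walk-through (illustrative, not part of the diff) shows where the 9216 input features of fc1 come from: a 28x28 MNIST image shrinks to 26x26 after the first 3x3 convolution, 24x24 after the second, and 12x12 after the 2x2 max-pool, giving 64 * 12 * 12 = 9216 values when flattened.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    x = torch.randn(1, 1, 28, 28)              # one dummy MNIST image
    x = F.relu(nn.Conv2d(1, 32, 3, 1)(x))      # -> (1, 32, 26, 26)
    x = F.relu(nn.Conv2d(32, 64, 3, 1)(x))     # -> (1, 64, 24, 24)
    x = F.max_pool2d(x, 2)                     # -> (1, 64, 12, 12)
    x = torch.flatten(x, 1)                    # -> (1, 9216)
    print(x.shape)                             # torch.Size([1, 9216])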
@@ -184,7 +192,7 @@ def forward(self, x):
 model = Net().to(device)
 
 # Load the pretrained model
-model.load_state_dict(torch.load(pretrained_model, weights_only=True, map_location='cpu'))
+model.load_state_dict(torch.load(pretrained_model, map_location=device))
 
 # Set the model in evaluation mode. In this case this is for the Dropout layers
 model.eval()
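A sanity check after loading (hypothetical, not in the diff): the checkpoint's tensor shapes should match the new architecture; with nn.Linear layers named fc1 and fc2, the state dict keys follow the module names.

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state = torch.load("data/lenet_mnist_model.pth", map_location=device)  # assumed path
    print(state["fc1.weight"].shape)  # expected: torch.Size([128, 9216])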
@@ -219,6 +227,26 @@ def fgsm_attack(image, epsilon, data_grad):
     # Return the perturbed image
     return perturbed_image
 
+# restores the tensors to their original scale
+def denorm(batch, mean=[0.1307], std=[0.3081]):
+    """
+    Convert a batch of tensors to their original scale.
+
+    Args:
+        batch (torch.Tensor): Batch of normalized tensors.
+        mean (torch.Tensor or list): Mean used for normalization.
+        std (torch.Tensor or list): Standard deviation used for normalization.
+
+    Returns:
+        torch.Tensor: batch of tensors without normalization applied to them.
+    """
+    if isinstance(mean, list):
+        mean = torch.tensor(mean).to(device)
+    if isinstance(std, list):
+        std = torch.tensor(std).to(device)
+
+    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)
+
 
 ######################################################################
 # Testing Function
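The added denorm function is the exact inverse of transforms.Normalize, which computes (x - mean) / std per channel; multiplying by std and adding the mean back recovers the original pixels. A round-trip check (illustrative only):

    import torch
    from torchvision import transforms

    mean, std = 0.1307, 0.3081
    img = torch.rand(1, 1, 28, 28)                       # fake batch in [0, 1]
    norm = transforms.Normalize((mean,), (std,))(img)    # as in the dataloader
    restored = norm * std + mean                         # what denorm() computes
    print(torch.allclose(restored, img, atol=1e-6))      # True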
@@ -273,11 +301,17 @@ def test( model, device, test_loader, epsilon ):
         # Collect ``datagrad``
         data_grad = data.grad.data
 
+        # Restore the data to its original scale
+        data_denorm = denorm(data)
+
         # Call FGSM Attack
-        perturbed_data = fgsm_attack(data, epsilon, data_grad)
+        perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)
+
+        # Reapply normalization
+        perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)
 
         # Re-classify the perturbed image
-        output = model(perturbed_data)
+        output = model(perturbed_data_normalized)
 
         # Check for success
         final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
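Putting the pieces together, the corrected attack now operates in the original [0, 1] pixel space and only renormalizes for classification. A condensed, self-contained sketch of that order of operations (stand-in values, not the tutorial's data):

    import torch

    mean, std, epsilon = 0.1307, 0.3081, 0.25
    x = torch.rand(1, 1, 28, 28)                         # denormed image in [0, 1]
    grad_sign = torch.randn(1, 1, 28, 28).sign()         # stand-in for data.grad's sign
    x_adv = torch.clamp(x + epsilon * grad_sign, 0, 1)   # perturb + clamp in pixel space
    x_adv_norm = (x_adv - mean) / std                    # renormalize before the model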
