-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathevaluate_on_pgd_attacks.py
115 lines (82 loc) · 3.43 KB
/
evaluate_on_pgd_attacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright (c) 2018-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import argparse
import torch
import torch.nn as nn
from advertorch.attacks import LinfPGDAttack, L2PGDAttack, ChooseBestAttack
from advertorch.attacks.utils import attack_whole_dataset
from advertorch.loss import CWLoss
from advertorch.utils import get_accuracy
from advertorch_examples.models import LeNet5Madry
from advertorch_examples.models import get_cifar10_wrn28_widen_factor
from advertorch_examples.utils import get_test_loader
def generate_adversaries(attack_class, nb_restart, **kwargs):
nb_cw = nb_restart // 2
nb_ce = nb_restart - nb_cw
adversaries = []
kwargs["loss_fn"] = nn.CrossEntropyLoss(reduction="sum")
ce_pgd = attack_class(**kwargs)
adversaries = [ce_pgd] * nb_ce
if nb_cw > 0:
kwargs["loss_fn"] = CWLoss(reduction="sum")
cw_pgd = attack_class(**kwargs)
adversaries += [cw_pgd] * nb_cw
return adversaries
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--device', default="cuda")
parser.add_argument('--deterministic', default=False, action="store_true")
parser.add_argument('--dataset', required=True)
parser.add_argument('--norm', required=True, choices=("Linf", "L2"))
parser.add_argument('--eps', required=True, type=float)
parser.add_argument('--eps_iter', default=None, type=float)
parser.add_argument('--nb_iter', default=100, type=int)
parser.add_argument('--nb_restart', default=10, type=int)
parser.add_argument('--model', required=True, type=str)
parser.add_argument('--test_size', default=None, type=int)
args = parser.parse_args()
ckpt = torch.load(args.model)
if args.dataset.upper() == "CIFAR10":
model = get_cifar10_wrn28_widen_factor(4)
elif args.dataset.upper() == "MNIST":
model = LeNet5Madry()
else:
raise
model.load_state_dict(ckpt["model"])
model.to(args.device)
model.eval()
print("model loaded")
if args.dataset.upper() == "CIFAR10" and args.norm == "Linf" \
and args.eps > 1.:
args.eps = round(args.eps / 255., 4)
if args.eps_iter is None:
if args.dataset.upper() == "MNIST" and args.norm == "Linf":
args.eps_iter = args.eps / 40.
else:
args.eps_iter = args.eps / 4.
test_loader = get_test_loader(
args.dataset.upper(), test_size=args.test_size, batch_size=100)
if args.norm == "Linf":
attack_class = LinfPGDAttack
elif args.norm == "L2":
attack_class = L2PGDAttack
else:
raise
base_adversaries = generate_adversaries(
attack_class, args.nb_restart, predict=model, eps=args.eps,
nb_iter=args.nb_iter, eps_iter=args.eps_iter, rand_init=True)
adversary = ChooseBestAttack(model, base_adversaries)
adv, label, pred, advpred = attack_whole_dataset(
adversary, test_loader, device=args.device)
print(get_accuracy(advpred, label))
print(get_accuracy(advpred, pred))
torch.save({"adv": adv}, os.path.join(
os.path.dirname(args.model), "advdata_eps-{}.pt".format(args.eps)))
torch.save(
{"label": label, "pred": pred, "advpred": advpred},
os.path.join(os.path.dirname(args.model),
"advlabel_eps-{}.pt".format(args.eps)))