
Commit 2cdf94e

[bic, icarl] Return all probabilities in order to compute top-k accuracy.

1 parent: f99bfb0

2 files changed: +31, -13 lines

inclearn/models/bic.py (+3, -1)

@@ -72,8 +72,10 @@ def _eval_task(self, loader):
             if self._task > 0:
                 logits = self._bic(logits)

+            logits = logits.detach()
+
             ytrue.append(input_dict["targets"].numpy())
-            ypred.append(torch.softmax(logits, dim=1).argmax(dim=1).cpu().numpy())
+            ypred.append(torch.softmax(logits, dim=1).cpu().numpy())

         ytrue = np.concatenate(ytrue)
         ypred = np.concatenate(ypred)
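
Since ypred now stores the full softmax distribution per sample instead of a single argmax, top-k accuracy can be computed downstream. A minimal sketch of that computation, assuming a (n_samples, n_classes) score array (the helper name topk_accuracy is ours, not part of this commit):

import numpy as np

def topk_accuracy(ypred, ytrue, k=5):
    # ypred: (n_samples, n_classes) probabilities or scores.
    # ytrue: (n_samples,) integer class labels.
    topk = np.argsort(ypred, axis=1)[:, -k:]  # k best classes per row
    return float(np.mean([ytrue[i] in topk[i] for i in range(len(ytrue))]))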

inclearn/models/icarl.py (+28, -12)

@@ -3,10 +3,12 @@
 import logging
 import os
 import pdb
+import pickle

 import numpy as np
 import torch
 from scipy.spatial.distance import cdist
+from torch import nn
 from torch.nn import functional as F
 from tqdm import tqdm

@@ -35,7 +37,7 @@ def __init__(self, args):
         self._disable_progressbar = args.get("no_progressbar", False)

         self._device = args["device"][0]
-        self._old_device = args["device"][-1]
+        self._multiple_devices = args["device"]

         self._opt_name = args["optimizer"]
         self._lr = args["lr"]

@@ -87,6 +89,16 @@ def __init__(self, args):

         self._epoch_metrics = collections.defaultdict(list)

+    def save_metadata(self, path):
+        logger.info("Saving metadata at {}.".format(path))
+        with open(path, "wb+") as f:
+            pickle.dump(self._herding_indexes, f)
+
+    def load_metadata(self, path):
+        logger.info("Loading metadata at {}.".format(path))
+        with open(path, "rb") as f:
+            self._herding_indexes = pickle.load(f)
+
     @property
     def epoch_metrics(self):
         return dict(self._epoch_metrics)
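
These helpers pickle the herding indexes so exemplar selection can be restored without recomputing it. A hedged usage sketch (the file names and call sites are assumptions, not part of this commit):

model.save_metadata("icarl_task3_meta.pkl")  # e.g. after finishing a task
# ... later, before resuming from a checkpoint:
model.load_metadata("icarl_task3_meta.pkl")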

@@ -150,6 +162,12 @@ def _training_step(self, train_loader, val_loader, initial_epoch, nb_epochs):
         best_epoch, best_acc = -1, -1.
         wait = 0

+        if len(self._multiple_devices) > 1:
+            logger.info("Duplicating model on {} gpus.".format(len(self._multiple_devices)))
+            training_network = nn.DataParallel(self._network, self._multiple_devices)
+        else:
+            training_network = self._network
+
         for epoch in range(initial_epoch, nb_epochs):
             self._metrics = collections.defaultdict(float)

@@ -166,7 +184,7 @@ def _training_step(self, train_loader, val_loader, initial_epoch, nb_epochs):
             memory_flags = input_dict["memory_flags"]

             self._optimizer.zero_grad()
-            loss = self._forward_loss(inputs, targets, memory_flags)
+            loss = self._forward_loss(training_network, inputs, targets, memory_flags)
             loss.backward()
             self._optimizer.step()

@@ -211,15 +229,15 @@ def _print_metrics(self, prog_bar, epoch, nb_epochs, nb_batches):
             )
         )

-    def _forward_loss(self, inputs, targets, memory_flags):
+    def _forward_loss(self, training_network, inputs, targets, memory_flags):
         inputs, targets = inputs.to(self._device), targets.to(self._device)
         onehot_targets = utils.to_onehot(targets, self._n_classes).to(self._device)

         if self._random_noise_config:
             random_noise = torch.randn(self._random_noise_config["nb_per_batch"], *inputs.shape[1:])
             inputs = torch.cat((inputs, random_noise.to(self._device)))

-        logits = self._network(inputs)
+        logits = training_network(inputs)

         loss = self._compute_loss(inputs, logits, targets, onehot_targets, memory_flags)

@@ -235,7 +253,7 @@ def _after_task(self, inc_dataset):
             inc_dataset, self._herding_indexes
         )

-        self._old_model = self._network.copy().freeze().to(self._old_device)
+        self._old_model = self._network.copy().freeze().to(self._device)

         self._network.on_task_end()
         self.plot_tsne()

@@ -249,9 +267,8 @@ def plot_tsne(self):
         )

     def _eval_task(self, data_loader):
-        ypred, ytrue = self.compute_accuracy(self._network, data_loader, self._class_means)
-
-        return ypred, ytrue
+        ypreds, ytrue = self.compute_accuracy(self._network, data_loader, self._class_means)
+        return ypreds, ytrue

     # -----------
     # Private API

@@ -261,8 +278,7 @@ def _compute_loss(self, inputs, logits, targets, onehot_targets, memory_flags):
         if self._old_model is None:
             loss = F.binary_cross_entropy_with_logits(logits, onehot_targets)
         else:
-            old_targets = torch.sigmoid(self._old_model(inputs.to(self._old_device)).detach()
-                                        ).to(self._device)
+            old_targets = torch.sigmoid(self._old_model(inputs).detach())

             new_targets = onehot_targets.clone()
             new_targets[..., :-self._task_size] = old_targets
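
This is the iCaRL-style distillation target: for the old classes, the one-hot targets are overwritten with the frozen previous model's sigmoid outputs before the binary cross-entropy loss. A toy sketch of the target construction, with shapes and sizes as assumptions:

import torch
from torch.nn import functional as F

n_classes, task_size = 10, 2                         # 8 old + 2 new classes
logits = torch.randn(4, n_classes)                   # current model, batch of 4
onehot_targets = torch.zeros(4, n_classes)           # ground-truth one-hots
old_logits = torch.randn(4, n_classes - task_size)   # frozen old model outputs

new_targets = onehot_targets.clone()
new_targets[..., :-task_size] = torch.sigmoid(old_logits)  # distilled part
loss = F.binary_cross_entropy_with_logits(logits, new_targets)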

@@ -334,7 +350,7 @@ def build_examplars(
         herding_indexes = copy.deepcopy(herding_indexes)

         data_memory, targets_memory = [], []
-        class_means = np.zeros((100, self._network.features_dim))
+        class_means = np.zeros((self._n_classes, self._network.features_dim))

         for class_idx in range(self._n_classes):
             # We extract the features, both normal and flipped:

@@ -410,4 +426,4 @@ def compute_accuracy(model, loader, class_means):
         sqd = cdist(class_means, features, 'sqeuclidean')
         score_icarl = (-sqd).T

-        return np.argsort(score_icarl, axis=1)[:, -1], targets_
+        return score_icarl, targets_
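
compute_accuracy now returns the full (n_samples, n_classes) matrix of negated squared distances to the class means rather than only the nearest class, so callers can rank every class. A hedged sketch of consuming it (variable names are ours):

import numpy as np

top1 = np.argsort(score_icarl, axis=1)[:, -1]   # the old single prediction
top5 = np.argsort(score_icarl, axis=1)[:, -5:]  # five best classes per sample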
