@@ -3,10 +3,12 @@
 import logging
 import os
 import pdb
+import pickle

 import numpy as np
 import torch
 from scipy.spatial.distance import cdist
+from torch import nn
 from torch.nn import functional as F
 from tqdm import tqdm

@@ -35,7 +37,7 @@ def __init__(self, args):
         self._disable_progressbar = args.get("no_progressbar", False)

         self._device = args["device"][0]
-        self._old_device = args["device"][-1]
+        self._multiple_devices = args["device"]

         self._opt_name = args["optimizer"]
         self._lr = args["lr"]
@@ -87,6 +89,16 @@ def __init__(self, args):

         self._epoch_metrics = collections.defaultdict(list)

+    def save_metadata(self, path):
+        logger.info("Saving metadata at {}.".format(path))
+        with open(path, "wb+") as f:
+            pickle.dump(self._herding_indexes, f)
+
+    def load_metadata(self, path):
+        logger.info("Loading metadata at {}.".format(path))
+        with open(path, "rb") as f:
+            self._herding_indexes = pickle.load(f)
+
     @property
     def epoch_metrics(self):
         return dict(self._epoch_metrics)
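
The two new hooks above persist only the herding indexes, so exemplar selection can be restored without recomputing it. A minimal usage sketch; `learner` stands for an instance of this class and the path is illustrative, not from the source:

# Hypothetical resume flow around the new metadata hooks.
learner.save_metadata("metadata_task0.pkl")   # after exemplar selection for a task
# ... on a later run, before rebuilding exemplars for the same task:
learner.load_metadata("metadata_task0.pkl")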
@@ -150,6 +162,12 @@ def _training_step(self, train_loader, val_loader, initial_epoch, nb_epochs):
         best_epoch, best_acc = -1, -1.
         wait = 0

+        if len(self._multiple_devices) > 1:
+            logger.info("Duplicating model on {} gpus.".format(len(self._multiple_devices)))
+            training_network = nn.DataParallel(self._network, self._multiple_devices)
+        else:
+            training_network = self._network
+
         for epoch in range(initial_epoch, nb_epochs):
             self._metrics = collections.defaultdict(float)

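
`nn.DataParallel` shares parameters with the wrapped module, so the optimizer built on `self._network` still updates the weights used by `training_network`. A standalone sketch of the same pattern; the toy module and device list are illustrative stand-ins, not names from the source:

import torch
from torch import nn

network = nn.Linear(16, 4)                         # stand-in for self._network
devices = list(range(torch.cuda.device_count()))   # stand-in for args["device"]

if len(devices) > 1:
    network = network.to(devices[0])               # DataParallel expects weights on device_ids[0]
    training_network = nn.DataParallel(network, devices)
else:
    training_network = network

# Forward pass works the same either way; gradients land on `network`.
inputs = torch.randn(8, 16).to(next(network.parameters()).device)
outputs = training_network(inputs)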
@@ -166,7 +184,7 @@ def _training_step(self, train_loader, val_loader, initial_epoch, nb_epochs):
                 memory_flags = input_dict["memory_flags"]

                 self._optimizer.zero_grad()
-                loss = self._forward_loss(inputs, targets, memory_flags)
+                loss = self._forward_loss(training_network, inputs, targets, memory_flags)
                 loss.backward()
                 self._optimizer.step()

@@ -211,15 +229,15 @@ def _print_metrics(self, prog_bar, epoch, nb_epochs, nb_batches):
             )
         )

-    def _forward_loss(self, inputs, targets, memory_flags):
+    def _forward_loss(self, training_network, inputs, targets, memory_flags):
         inputs, targets = inputs.to(self._device), targets.to(self._device)
         onehot_targets = utils.to_onehot(targets, self._n_classes).to(self._device)

         if self._random_noise_config:
             random_noise = torch.randn(self._random_noise_config["nb_per_batch"], *inputs.shape[1:])
             inputs = torch.cat((inputs, random_noise.to(self._device)))

-        logits = self._network(inputs)
+        logits = training_network(inputs)

         loss = self._compute_loss(inputs, logits, targets, onehot_targets, memory_flags)

@@ -235,7 +253,7 @@ def _after_task(self, inc_dataset):
             inc_dataset, self._herding_indexes
         )

-        self._old_model = self._network.copy().freeze().to(self._old_device)
+        self._old_model = self._network.copy().freeze().to(self._device)

         self._network.on_task_end()
         self.plot_tsne()
@@ -249,9 +267,8 @@ def plot_tsne(self):
             )

     def _eval_task(self, data_loader):
-        ypred, ytrue = self.compute_accuracy(self._network, data_loader, self._class_means)
-
-        return ypred, ytrue
+        ypreds, ytrue = self.compute_accuracy(self._network, data_loader, self._class_means)
+        return ypreds, ytrue

     # -----------
     # Private API
@@ -261,8 +278,7 @@ def _compute_loss(self, inputs, logits, targets, onehot_targets, memory_flags):
         if self._old_model is None:
             loss = F.binary_cross_entropy_with_logits(logits, onehot_targets)
         else:
-            old_targets = torch.sigmoid(self._old_model(inputs.to(self._old_device)).detach()
-                                        ).to(self._device)
+            old_targets = torch.sigmoid(self._old_model(inputs).detach())

             new_targets = onehot_targets.clone()
             new_targets[..., :-self._task_size] = old_targets
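
The unchanged context lines implement iCaRL's distillation targets: sigmoid outputs of the frozen previous model overwrite the one-hot entries for old classes, and after this change the old model runs on the same device as the current one. A self-contained sketch of that target construction; the shapes and random tensor are illustrative only:

import torch

batch, n_classes, task_size = 4, 10, 2
onehot_targets = torch.zeros(batch, n_classes)
old_targets = torch.rand(batch, n_classes - task_size)  # stands in for sigmoid(old_model(inputs))

new_targets = onehot_targets.clone()
new_targets[..., :-task_size] = old_targets  # distill old classes, keep one-hot for new ones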
@@ -334,7 +350,7 @@ def build_examplars(
         herding_indexes = copy.deepcopy(herding_indexes)

         data_memory, targets_memory = [], []
-        class_means = np.zeros((100, self._network.features_dim))
+        class_means = np.zeros((self._n_classes, self._network.features_dim))

         for class_idx in range(self._n_classes):
             # We extract the features, both normal and flipped:
@@ -410,4 +426,4 @@ def compute_accuracy(model, loader, class_means):
         sqd = cdist(class_means, features, 'sqeuclidean')
         score_icarl = (-sqd).T

-        return np.argsort(score_icarl, axis=1)[:, -1], targets_
+        return score_icarl, targets_
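
With this last hunk, `compute_accuracy` returns the full matrix of negated squared distances to the class means instead of hard top-1 labels, leaving the argmax to the caller. A sketch of recovering the old behaviour, assuming `ypreds` has shape `(n_samples, n_classes)` and `ytrue` shape `(n_samples,)` as returned by the modified `_eval_task`:

import numpy as np

top1 = ypreds.argmax(axis=1)          # equivalent to the removed np.argsort(...)[:, -1]
top1_accuracy = (top1 == ytrue).mean()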