-
Notifications
You must be signed in to change notification settings - Fork 95
/
Copy pathsnapshot_ensemble.py
556 lines (462 loc) · 19.4 KB
/
snapshot_ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
"""
Snapshot ensemble generates many base estimators by forcing a single base
estimator to converge to a local minimum multiple times, saving the model
parameters at each of those points as a snapshot. The final prediction
takes the average over predictions from all snapshot models.
Reference:
G. Huang, Y.-X. Li, G. Pleiss et al., Snapshot Ensemble: Train 1, and
M for free, ICLR, 2017.
"""
import math
import torch
import logging
import warnings
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from ._base import BaseModule, BaseClassifier, BaseRegressor
from ._base import torchensemble_model_doc
from .utils import io
from .utils import set_module
from .utils import operator as op
from .utils.logging import get_tb_logger
# Names exported by `from <module> import *`.
__all__ = ["SnapshotEnsembleClassifier", "SnapshotEnsembleRegressor"]
# Parameter documentation for the `fit` methods of the snapshot ensemble
# models. `_snapshot_ensemble_model_doc` appends this text to a header and
# installs the result as the `__doc__` of the decorated object.
__fit_doc = """
Parameters
----------
train_loader : torch.utils.data.DataLoader
A :mod:`DataLoader` container that contains the training data.
lr_clip : list or tuple, default=None
Specify the accepted range of learning rate. When the learning rate
determined by the scheduler is out of this range, it will be clipped.
- The first element should be the lower bound of learning rate.
- The second element should be the upper bound of learning rate.
epochs : int, default=100
The number of training epochs.
log_interval : int, default=100
The number of batches to wait before logging the training status.
test_loader : torch.utils.data.DataLoader, default=None
A :mod:`DataLoader` container that contains the evaluating data.
- If ``None``, no validation is conducted after each snapshot
being generated.
- If not ``None``, the ensemble will be evaluated on this
dataloader after each snapshot being generated.
save_model : bool, default=True
Specify whether to save the model parameters.
- If test_loader is ``None``, the ensemble with
``n_estimators`` base estimators will be saved.
- If test_loader is not ``None``, the ensemble with the best
validation performance will be saved.
save_dir : string, default=None
Specify where to save the model parameters.
- If ``None``, the model will be saved in the current directory.
- If not ``None``, the model will be saved in the specified
directory: ``save_dir``.
"""
def _snapshot_ensemble_model_doc(header, item="fit"):
    """
    Build a decorator that installs snapshot ensemble documentation.

    The decorated object's ``__doc__`` is replaced by ``header``, a blank
    line, and the documentation text registered under ``item``. A
    ``KeyError`` is raised at decoration time if ``item`` is unknown.
    """

    def adddoc(cls):
        # Concatenate the header first, then look up the requested text.
        prefix = header + "\n\n"
        available = {"fit": __fit_doc}
        cls.__doc__ = prefix + available[item]
        return cls

    return adddoc
class _BaseSnapshotEnsemble(BaseModule):
    """
    Common base of the snapshot ensemble classifier and regressor.

    It stores the base estimator specification, validates the training
    hyper-parameters, clips the learning rate, and builds the cyclic cosine
    learning rate schedule that drives snapshot generation.
    """

    def __init__(
        self, estimator, n_estimators, estimator_args=None, cuda=True
    ):
        """
        Parameters
        ----------
        estimator : class or object
            The base estimator, given either as a class or as an already
            instantiated object.
        n_estimators : int
            The number of snapshots; also the number of learning rate
            cycles used during training.
        estimator_args : dict, default=None
            Arguments for ``estimator``. A ``RuntimeWarning`` is issued if
            they are given together with an instantiated ``estimator``.
        cuda : bool, default=True
            Select the ``"cuda"`` device when true, ``"cpu"`` otherwise.
        """
        # `super(BaseModule, self)` resumes the MRO lookup *after*
        # `BaseModule`, so `BaseModule.__init__` is not executed here and
        # every attribute is initialized below instead.
        super(BaseModule, self).__init__()
        self.base_estimator_ = estimator
        self.n_estimators = n_estimators
        self.estimator_args = estimator_args

        if estimator_args and not isinstance(estimator, type):
            msg = (
                "The input `estimator_args` will have no effect since"
                " `estimator` is already an object after instantiation."
            )
            warnings.warn(msg, RuntimeWarning)

        self.device = torch.device("cuda" if cuda else "cpu")
        self.logger = logging.getLogger()
        self.tb_logger = get_tb_logger()

        # Snapshots are appended to this list by `fit`.
        self.estimators_ = nn.ModuleList()

    def _validate_parameters(self, lr_clip, epochs, log_interval):
        """
        Validate hyper-parameters on training the ensemble.

        Raises
        ------
        ValueError
            If ``lr_clip`` is given but is not a two-element list/tuple
            with the first element smaller than the second, if ``epochs``
            or ``log_interval`` is not strictly positive, or if ``epochs``
            is not a multiple of ``n_estimators``. Each message is also
            sent to ``self.logger.error`` before raising.
        """
        if lr_clip:
            if not isinstance(lr_clip, (list, tuple)):
                msg = "lr_clip should be a list or tuple with two elements."
                self.logger.error(msg)
                raise ValueError(msg)

            if len(lr_clip) != 2:
                msg = (
                    "lr_clip should only have two elements, one for lower"
                    " bound, and another for upper bound."
                )
                self.logger.error(msg)
                raise ValueError(msg)

            if not lr_clip[0] < lr_clip[1]:
                msg = (
                    "The first element = {} should be smaller than the"
                    " second element = {} in lr_clip."
                )
                self.logger.error(msg.format(lr_clip[0], lr_clip[1]))
                raise ValueError(msg.format(lr_clip[0], lr_clip[1]))

        if not epochs > 0:
            msg = (
                "The number of training epochs = {} should be strictly"
                " positive."
            )
            self.logger.error(msg.format(epochs))
            raise ValueError(msg.format(epochs))

        if not log_interval > 0:
            msg = (
                "The number of batches to wait before printting the"
                " training status should be strictly positive, but got {}"
                " instead."
            )
            self.logger.error(msg.format(log_interval))
            raise ValueError(msg.format(log_interval))

        # Each snapshot is trained for `epochs / n_estimators` epochs, so
        # the division has to be exact.
        if not epochs % self.n_estimators == 0:
            msg = (
                "The number of training epochs = {} should be a multiple"
                " of n_estimators = {}."
            )
            self.logger.error(msg.format(epochs, self.n_estimators))
            raise ValueError(msg.format(epochs, self.n_estimators))

    def _forward(self, *x):
        """
        Implementation on the internal data forwarding in snapshot ensemble.

        Returns the average (``op.average``) of the outputs of all stored
        snapshots on the input ``x``.
        """
        results = [estimator(*x) for estimator in self.estimators_]
        output = op.average(results)

        return output

    def _clip_lr(self, optimizer, lr_clip):
        """
        Clip the learning rate of the optimizer according to `lr_clip`.

        Every parameter group whose learning rate lies outside
        ``[lr_clip[0], lr_clip[1]]`` is moved to the nearest bound. The
        optimizer is modified in place and returned; it is returned
        untouched when ``lr_clip`` is empty or ``None``.
        """
        if not lr_clip:
            return optimizer

        for param_group in optimizer.param_groups:
            if param_group["lr"] < lr_clip[0]:
                param_group["lr"] = lr_clip[0]
            if param_group["lr"] > lr_clip[1]:
                param_group["lr"] = lr_clip[1]

        return optimizer

    def _set_scheduler(self, optimizer, n_iters):
        """
        Set the learning rate scheduler for snapshot ensemble.

        The multiplicative factor restarts at 1 every
        ``T_M = ceil(n_iters / n_estimators)`` iterations and follows half
        a cosine period within each cycle.
        Please refer to the equation (2) in original paper for details.
        """
        T_M = math.ceil(n_iters / self.n_estimators)

        def lr_lambda(iteration):
            # Use `math.cos` so the factor is a plain Python float; a
            # 0-dim tensor here would be written into
            # `param_group["lr"]` by `LambdaLR`.
            return 0.5 * (math.cos(math.pi * (iteration % T_M) / T_M) + 1)

        scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

        return scheduler

    def set_scheduler(self, scheduler_name, **kwargs):
        """
        Warn that a user-defined scheduler is ignored.

        The arguments are not used; the method only issues a
        ``RuntimeWarning``.
        """
        msg = (
            "The learning rate scheduler for Snapshot Ensemble will be"
            " automatically set. Calling this function has no effect on"
            " the training stage of Snapshot Ensemble."
        )
        warnings.warn(msg, RuntimeWarning)
@torchensemble_model_doc(
    """Implementation on the SnapshotEnsembleClassifier.""", "seq_model"
)
class SnapshotEnsembleClassifier(_BaseSnapshotEnsemble, BaseClassifier):
    # NOTE(review): the docstrings of this class and of its decorated
    # methods are presumably supplied by `torchensemble_model_doc`, which is
    # defined outside this file; plain comments are used here for that
    # reason.

    def __init__(self, voting_strategy="soft", **kwargs):
        """
        Parameters
        ----------
        voting_strategy : {"soft", "hard"}, default="soft"
            How snapshot predictions are combined in ``forward``.
        **kwargs
            Forwarded unchanged to the parent constructor.

        Raises
        ------
        ValueError
            If ``voting_strategy`` is neither ``"soft"`` nor ``"hard"``.
        """
        super().__init__(**kwargs)

        implemented_strategies = {"soft", "hard"}
        if voting_strategy not in implemented_strategies:
            msg = (
                "Voting strategy {} is not implemented, "
                "please choose from {}."
            )
            raise ValueError(
                msg.format(voting_strategy, implemented_strategies)
            )

        self.voting_strategy = voting_strategy

    @torchensemble_model_doc(
        """Implementation on the data forwarding in SnapshotEnsembleClassifier.""",  # noqa: E501
        "classifier_forward",
    )
    def forward(self, *x):
        # Per-snapshot class probabilities.
        outputs = [
            F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1)
            for estimator in self.estimators_
        ]

        # "soft" averages the probabilities; any other (validated) value,
        # i.e. "hard", delegates to `op.majority_vote`.
        if self.voting_strategy == "soft":
            proba = op.average(outputs)
        else:
            proba = op.majority_vote(outputs)

        return proba

    @torchensemble_model_doc(
        """Set the attributes on optimizer for SnapshotEnsembleClassifier.""",
        "set_optimizer",
    )
    def set_optimizer(self, optimizer_name, **kwargs):
        super().set_optimizer(optimizer_name, **kwargs)

    @torchensemble_model_doc(
        """Set the training criterion for SnapshotEnsembleClassifier.""",
        "set_criterion",
    )
    def set_criterion(self, criterion):
        super().set_criterion(criterion)

    @_snapshot_ensemble_model_doc(
        """Implementation on the training stage of SnapshotEnsembleClassifier.""",  # noqa: E501
        "fit",
    )
    def fit(
        self,
        train_loader,
        lr_clip=None,
        epochs=100,
        log_interval=100,
        test_loader=None,
        save_model=True,
        save_dir=None,
    ):
        # The parameter documentation is attached by the decorator above.
        self._validate_parameters(lr_clip, epochs, log_interval)
        self.n_outputs = self._decide_n_outputs(train_loader)

        # A single estimator is trained; snapshots are copies of its state.
        estimator = self._make_estimator()

        # Set the optimizer and scheduler
        optimizer = set_module.set_optimizer(
            estimator, self.optimizer_name, **self.optimizer_args
        )
        scheduler = self._set_scheduler(optimizer, epochs * len(train_loader))

        # Check the training criterion
        if not hasattr(self, "_criterion"):
            self._criterion = nn.CrossEntropyLoss()

        # Utils
        best_acc = 0.0
        counter = 0  # a counter on generating snapshots
        total_iters = 0
        n_iters_per_estimator = epochs * len(train_loader) // self.n_estimators

        # Training loop
        estimator.train()
        for epoch in range(epochs):
            for batch_idx, elem in enumerate(train_loader):
                data, target = io.split_data_target(elem, self.device)
                batch_size = data[0].size(0)

                # Clip the learning rate
                optimizer = self._clip_lr(optimizer, lr_clip)

                optimizer.zero_grad()
                output = estimator(*data)
                loss = self._criterion(output, target)
                loss.backward()
                optimizer.step()

                # Print training status
                if batch_idx % log_interval == 0:
                    with torch.no_grad():
                        _, predicted = torch.max(output.data, 1)
                        correct = (predicted == target).sum().item()

                        msg = (
                            "lr: {:.5f} | Epoch: {:03d} | Batch: {:03d} |"
                            " Loss: {:.5f} | Correct: {:d}/{:d}"
                        )
                        self.logger.info(
                            msg.format(
                                optimizer.param_groups[0]["lr"],
                                epoch,
                                batch_idx,
                                loss,
                                correct,
                                batch_size,
                            )
                        )
                        # Nothing is emitted when no TensorBoard logger is
                        # configured.
                        if self.tb_logger:
                            self.tb_logger.add_scalar(
                                "snapshot_ensemble/Train_Loss",
                                loss,
                                total_iters,
                            )

                # Snapshot ensemble updates the learning rate per iteration
                # instead of per epoch.
                scheduler.step()
                counter += 1
                total_iters += 1

            # `epochs` is a multiple of `n_estimators`, so a learning rate
            # cycle always ends on an epoch boundary.
            if counter % n_iters_per_estimator == 0:
                # Generate and save the snapshot
                snapshot = self._make_estimator()
                snapshot.load_state_dict(estimator.state_dict())
                self.estimators_.append(snapshot)

                msg = "Save the snapshot model with index: {}"
                self.logger.info(msg.format(len(self.estimators_) - 1))

            # Validation after each snapshot model being generated
            if test_loader and counter % n_iters_per_estimator == 0:
                self.eval()
                with torch.no_grad():
                    correct = 0
                    total = 0
                    for elem in test_loader:
                        data, target = io.split_data_target(
                            elem, self.device
                        )
                        output = self.forward(*data)
                        _, predicted = torch.max(output.data, 1)
                        correct += (predicted == target).sum().item()
                        total += target.size(0)
                    acc = 100 * correct / total

                    # Only the best ensemble seen so far is saved.
                    if acc > best_acc:
                        best_acc = acc
                        if save_model:
                            io.save(self, save_dir, self.logger)

                    msg = (
                        "n_estimators: {} | Validation Acc: {:.3f} %"
                        " | Historical Best: {:.3f} %"
                    )
                    self.logger.info(
                        msg.format(len(self.estimators_), acc, best_acc)
                    )
                    if self.tb_logger:
                        self.tb_logger.add_scalar(
                            "snapshot_ensemble/Validation_Acc",
                            acc,
                            len(self.estimators_),
                        )

        # Without validation data the final ensemble is saved.
        if save_model and not test_loader:
            io.save(self, save_dir, self.logger)

    @torchensemble_model_doc(item="classifier_evaluate")
    def evaluate(self, test_loader, return_loss=False):
        return super().evaluate(test_loader, return_loss)

    @torchensemble_model_doc(item="predict")
    def predict(self, *x):
        return super().predict(*x)
@torchensemble_model_doc(
    """Implementation on the SnapshotEnsembleRegressor.""", "seq_model"
)
class SnapshotEnsembleRegressor(_BaseSnapshotEnsemble, BaseRegressor):
    # NOTE(review): the docstrings of this class and of its decorated
    # methods are presumably supplied by `torchensemble_model_doc`, which is
    # defined outside this file; plain comments are used here for that
    # reason.

    @torchensemble_model_doc(
        """Implementation on the data forwarding in SnapshotEnsembleRegressor.""",  # noqa: E501
        "regressor_forward",
    )
    def forward(self, *x):
        # Average of the snapshot outputs, computed by the base class.
        pred = self._forward(*x)
        return pred

    @torchensemble_model_doc(
        """Set the attributes on optimizer for SnapshotEnsembleRegressor.""",
        "set_optimizer",
    )
    def set_optimizer(self, optimizer_name, **kwargs):
        super().set_optimizer(optimizer_name, **kwargs)

    @torchensemble_model_doc(
        """Set the training criterion for SnapshotEnsembleRegressor.""",
        "set_criterion",
    )
    def set_criterion(self, criterion):
        super().set_criterion(criterion)

    @_snapshot_ensemble_model_doc(
        """Implementation on the training stage of SnapshotEnsembleRegressor.""",  # noqa: E501
        "fit",
    )
    def fit(
        self,
        train_loader,
        lr_clip=None,
        epochs=100,
        log_interval=100,
        test_loader=None,
        save_model=True,
        save_dir=None,
    ):
        # The parameter documentation is attached by the decorator above.
        self._validate_parameters(lr_clip, epochs, log_interval)
        self.n_outputs = self._decide_n_outputs(train_loader)

        # A single estimator is trained; snapshots are copies of its state.
        estimator = self._make_estimator()

        # Set the optimizer and scheduler
        optimizer = set_module.set_optimizer(
            estimator, self.optimizer_name, **self.optimizer_args
        )
        scheduler = self._set_scheduler(optimizer, epochs * len(train_loader))

        # Check the training criterion
        if not hasattr(self, "_criterion"):
            self._criterion = nn.MSELoss()

        # Utils
        best_loss = float("inf")
        counter = 0  # a counter on generating snapshots
        total_iters = 0
        n_iters_per_estimator = epochs * len(train_loader) // self.n_estimators

        # Training loop
        estimator.train()
        for epoch in range(epochs):
            for batch_idx, elem in enumerate(train_loader):
                data, target = io.split_data_target(elem, self.device)

                # Clip the learning rate
                optimizer = self._clip_lr(optimizer, lr_clip)

                optimizer.zero_grad()
                output = estimator(*data)
                loss = self._criterion(output, target)
                loss.backward()
                optimizer.step()

                # Print training status
                if batch_idx % log_interval == 0:
                    with torch.no_grad():
                        msg = (
                            "lr: {:.5f} | Epoch: {:03d} | Batch: {:03d}"
                            " | Loss: {:.5f}"
                        )
                        self.logger.info(
                            msg.format(
                                optimizer.param_groups[0]["lr"],
                                epoch,
                                batch_idx,
                                loss,
                            )
                        )
                        if self.tb_logger:
                            self.tb_logger.add_scalar(
                                "snapshot_ensemble/Train_Loss",
                                loss,
                                total_iters,
                            )

                # Snapshot ensemble updates the learning rate per iteration
                # instead of per epoch.
                scheduler.step()
                counter += 1
                total_iters += 1

            # `epochs` is a multiple of `n_estimators`, so a learning rate
            # cycle always ends on an epoch boundary.
            if counter % n_iters_per_estimator == 0:
                # Generate and save the snapshot
                snapshot = self._make_estimator()
                snapshot.load_state_dict(estimator.state_dict())
                self.estimators_.append(snapshot)

                msg = "Save the snapshot model with index: {}"
                self.logger.info(msg.format(len(self.estimators_) - 1))

            # Validation after each snapshot model being generated
            if test_loader and counter % n_iters_per_estimator == 0:
                self.eval()
                with torch.no_grad():
                    # Accumulate plain floats so that `val_loss` and
                    # `best_loss` never become tensors.
                    val_loss = 0.0
                    for elem in test_loader:
                        data, target = io.split_data_target(
                            elem, self.device
                        )
                        output = self.forward(*data)
                        val_loss += self._criterion(output, target).item()
                    val_loss /= len(test_loader)

                    # Only the best ensemble seen so far is saved.
                    if val_loss < best_loss:
                        best_loss = val_loss
                        if save_model:
                            io.save(self, save_dir, self.logger)

                    msg = (
                        "n_estimators: {} | Validation Loss: {:.5f} |"
                        " Historical Best: {:.5f}"
                    )
                    self.logger.info(
                        msg.format(len(self.estimators_), val_loss, best_loss)
                    )
                    if self.tb_logger:
                        self.tb_logger.add_scalar(
                            "snapshot_ensemble/Validation_Loss",
                            val_loss,
                            len(self.estimators_),
                        )

        # Without validation data the final ensemble is saved.
        if save_model and not test_loader:
            io.save(self, save_dir, self.logger)

    @torchensemble_model_doc(item="regressor_evaluate")
    def evaluate(self, test_loader):
        return super().evaluate(test_loader)

    @torchensemble_model_doc(item="predict")
    def predict(self, *x):
        return super().predict(*x)