-
Notifications
You must be signed in to change notification settings - Fork 95
/
Copy pathfusion.py
328 lines (277 loc) · 11.1 KB
/
fusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
"""
In a fusion-based ensemble, the predictions of all base estimators are
first aggregated into an average output. The training loss is then
computed from this averaged output and the ground truth, and is
back-propagated to all base estimators simultaneously.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from ._base import BaseClassifier, BaseRegressor
from ._base import torchensemble_model_doc
from .utils import io
from .utils import set_module
from .utils import operator as op
# Public API of this module: the two fusion ensembles defined in this file.
__all__ = ["FusionClassifier", "FusionRegressor"]
@torchensemble_model_doc(
    """Implementation on the FusionClassifier.""", "model"
)
class FusionClassifier(BaseClassifier):
    def _forward(self, *x):
        """
        Implementation on the internal data forwarding in FusionClassifier.

        Runs every base estimator on the same inputs and combines their raw
        outputs with ``op.average`` (no softmax is applied here). ``fit``
        computes the training loss on this combined output.
        """
        # Average
        outputs = [estimator(*x) for estimator in self.estimators_]
        output = op.average(outputs)
        return output

    @torchensemble_model_doc(
        """Implementation on the data forwarding in FusionClassifier.""",
        "classifier_forward",
    )
    def forward(self, *x):
        # Class probabilities: softmax over dim 1 of the averaged outputs.
        output = op.unsqueeze_tensor(self._forward(*x))
        proba = F.softmax(output, dim=1)
        return proba

    # The thin overrides below only delegate to the base class; they exist so
    # that `torchensemble_model_doc` can attach FusionClassifier-specific docs.
    @torchensemble_model_doc(
        """Set the attributes on optimizer for FusionClassifier.""",
        "set_optimizer",
    )
    def set_optimizer(self, optimizer_name, **kwargs):
        super().set_optimizer(optimizer_name, **kwargs)

    @torchensemble_model_doc(
        """Set the attributes on scheduler for FusionClassifier.""",
        "set_scheduler",
    )
    def set_scheduler(self, scheduler_name, **kwargs):
        super().set_scheduler(scheduler_name, **kwargs)

    @torchensemble_model_doc(
        """Set the training criterion for FusionClassifier.""",
        "set_criterion",
    )
    def set_criterion(self, criterion):
        super().set_criterion(criterion)

    def _validation_accuracy(self, test_loader):
        """
        Return the top-1 accuracy (in percent) of the ensemble on
        ``test_loader``, or ``None`` if the loader yields no samples.

        The loader is only iterated. It is never passed to ``len``/``bool``,
        so loaders over an ``IterableDataset`` without ``__len__`` work.
        """
        self.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for elem in test_loader:
                data, target = io.split_data_target(elem, self.device)
                output = self.forward(*data)
                # argmax returns the first maximal index, like torch.max.
                predicted = output.argmax(dim=1)
                correct += (predicted == target).sum().item()
                total += target.size(0)
        if total == 0:
            return None
        return 100 * correct / total

    @torchensemble_model_doc(
        """Implementation on the training stage of FusionClassifier.""", "fit"
    )
    def fit(
        self,
        train_loader,
        epochs=100,
        log_interval=100,
        test_loader=None,
        save_model=True,
        save_dir=None,
    ):
        # Instantiate base estimators and set attributes.
        # NOTE(review): estimators are appended on every call, so calling
        # `fit` twice on the same object grows `estimators_`. Confirm whether
        # repeated fitting is meant to be supported.
        for _ in range(self.n_estimators):
            self.estimators_.append(self._make_estimator())
        self._validate_parameters(epochs, log_interval)
        self.n_outputs = self._decide_n_outputs(train_loader)
        # One optimizer built from the whole ensemble (`self`); presumably it
        # covers every base estimator's parameters. Confirm in set_module.
        optimizer = set_module.set_optimizer(
            self, self.optimizer_name, **self.optimizer_args
        )

        # Set the scheduler if `set_scheduler` was called before
        if self.use_scheduler_:
            self.scheduler_ = set_module.set_scheduler(
                optimizer, self.scheduler_name, **self.scheduler_args
            )

        # Check the training criterion
        if not hasattr(self, "_criterion"):
            self._criterion = nn.CrossEntropyLoss()

        # Utils
        best_acc = 0.0
        total_iters = 0
        # Becomes True once any epoch produced a validation accuracy. If it
        # never does, the final model is saved after training instead.
        validated = False

        # Training loop
        for epoch in range(epochs):
            self.train()
            for batch_idx, elem in enumerate(train_loader):

                data, target = io.split_data_target(elem, self.device)
                batch_size = data[0].size(0)

                optimizer.zero_grad()
                # The loss is computed on the *averaged* output, so a single
                # backward pass sends gradients to all base estimators.
                output = self._forward(*data)
                loss = self._criterion(output, target)
                loss.backward()
                optimizer.step()

                # Print training status
                if batch_idx % log_interval == 0:
                    with torch.no_grad():
                        predicted = output.argmax(dim=1)
                        correct = (predicted == target).sum().item()
                        loss_value = loss.item()

                        msg = (
                            "Epoch: {:03d} | Batch: {:03d} | Loss:"
                            " {:.5f} | Correct: {:d}/{:d}"
                        )
                        self.logger.info(
                            msg.format(
                                epoch,
                                batch_idx,
                                loss_value,
                                correct,
                                batch_size,
                            )
                        )
                        if self.tb_logger:
                            self.tb_logger.add_scalar(
                                "fusion/Train_Loss", loss_value, total_iters
                            )
                total_iters += 1

            # Validation (skipped when no test_loader is given or it is empty)
            acc = None
            if test_loader is not None:
                acc = self._validation_accuracy(test_loader)

            if acc is not None:
                validated = True
                if acc > best_acc:
                    best_acc = acc
                    if save_model:
                        io.save(self, save_dir, self.logger)

                msg = (
                    "Epoch: {:03d} | Validation Acc: {:.3f}"
                    " % | Historical Best: {:.3f} %"
                )
                self.logger.info(msg.format(epoch, acc, best_acc))
                if self.tb_logger:
                    self.tb_logger.add_scalar(
                        "fusion/Validation_Acc", acc, epoch
                    )

            # Update the scheduler
            if hasattr(self, "scheduler_"):
                if self.scheduler_name == "ReduceLROnPlateau":
                    # NOTE(review): the plateau metric is validation accuracy
                    # (higher is better) when available. Otherwise it is the
                    # last batch's training loss (lower is better). The
                    # `mode` passed to `set_scheduler` must match.
                    if acc is not None:
                        self.scheduler_.step(acc)
                    else:
                        self.scheduler_.step(loss.item())
                else:
                    self.scheduler_.step()

        if save_model and not validated:
            io.save(self, save_dir, self.logger)

    @torchensemble_model_doc(item="classifier_evaluate")
    def evaluate(self, test_loader, return_loss=False):
        return super().evaluate(test_loader, return_loss)

    @torchensemble_model_doc(item="predict")
    def predict(self, *x):
        return super().predict(*x)
@torchensemble_model_doc("""Implementation on the FusionRegressor.""", "model")
class FusionRegressor(BaseRegressor):
    @torchensemble_model_doc(
        """Implementation on the data forwarding in FusionRegressor.""",
        "regressor_forward",
    )
    def forward(self, *x):
        # Average: run every base estimator on the same inputs and combine
        # their outputs with `op.average`. Training also uses this output.
        outputs = [estimator(*x) for estimator in self.estimators_]
        pred = op.average(outputs)
        return pred

    # The thin overrides below only delegate to the base class; they exist so
    # that `torchensemble_model_doc` can attach FusionRegressor-specific docs.
    @torchensemble_model_doc(
        """Set the attributes on optimizer for FusionRegressor.""",
        "set_optimizer",
    )
    def set_optimizer(self, optimizer_name, **kwargs):
        super().set_optimizer(optimizer_name, **kwargs)

    @torchensemble_model_doc(
        """Set the attributes on scheduler for FusionRegressor.""",
        "set_scheduler",
    )
    def set_scheduler(self, scheduler_name, **kwargs):
        super().set_scheduler(scheduler_name, **kwargs)

    @torchensemble_model_doc(
        """Set the training criterion for FusionRegressor.""",
        "set_criterion",
    )
    def set_criterion(self, criterion):
        super().set_criterion(criterion)

    def _validation_loss(self, test_loader):
        """
        Return the mean of the per-batch criterion values on ``test_loader``
        as a float, or ``None`` if the loader yields no batches.

        Batches are counted while iterating instead of using
        ``len(test_loader)``. Loaders over an ``IterableDataset`` without
        ``__len__`` therefore work. For map-style loaders the count equals
        ``len(test_loader)``.

        Note that this averages per batch, not per sample. A smaller final
        batch carries the same weight as a full one.
        """
        self.eval()
        total_loss = 0.0
        n_batches = 0
        with torch.no_grad():
            for elem in test_loader:
                data, target = io.split_data_target(elem, self.device)
                output = self.forward(*data)
                total_loss += self._criterion(output, target)
                n_batches += 1
        if n_batches == 0:
            return None
        return float(total_loss / n_batches)

    @torchensemble_model_doc(
        """Implementation on the training stage of FusionRegressor.""", "fit"
    )
    def fit(
        self,
        train_loader,
        epochs=100,
        log_interval=100,
        test_loader=None,
        save_model=True,
        save_dir=None,
    ):
        # Instantiate base estimators and set attributes.
        # NOTE(review): estimators are appended on every call, so calling
        # `fit` twice on the same object grows `estimators_`. Confirm whether
        # repeated fitting is meant to be supported.
        for _ in range(self.n_estimators):
            self.estimators_.append(self._make_estimator())
        self._validate_parameters(epochs, log_interval)
        self.n_outputs = self._decide_n_outputs(train_loader)
        # One optimizer built from the whole ensemble (`self`); presumably it
        # covers every base estimator's parameters. Confirm in set_module.
        optimizer = set_module.set_optimizer(
            self, self.optimizer_name, **self.optimizer_args
        )

        # Set the scheduler if `set_scheduler` was called before
        if self.use_scheduler_:
            self.scheduler_ = set_module.set_scheduler(
                optimizer, self.scheduler_name, **self.scheduler_args
            )

        # Check the training criterion
        if not hasattr(self, "_criterion"):
            self._criterion = nn.MSELoss()

        # Utils
        best_loss = float("inf")
        total_iters = 0
        # Becomes True once any epoch produced a validation loss. If it never
        # does, the final model is saved after training instead.
        validated = False

        # Training loop
        for epoch in range(epochs):
            self.train()
            for batch_idx, elem in enumerate(train_loader):

                data, target = io.split_data_target(elem, self.device)

                optimizer.zero_grad()
                # The loss is computed on the *averaged* output, so a single
                # backward pass sends gradients to all base estimators.
                output = self.forward(*data)
                loss = self._criterion(output, target)
                loss.backward()
                optimizer.step()

                # Print training status
                if batch_idx % log_interval == 0:
                    loss_value = loss.item()
                    msg = "Epoch: {:03d} | Batch: {:03d} | Loss: {:.5f}"
                    self.logger.info(msg.format(epoch, batch_idx, loss_value))
                    if self.tb_logger:
                        self.tb_logger.add_scalar(
                            "fusion/Train_Loss", loss_value, total_iters
                        )
                total_iters += 1

            # Validation (skipped when no test_loader is given or it is empty)
            val_loss = None
            if test_loader is not None:
                val_loss = self._validation_loss(test_loader)

            if val_loss is not None:
                validated = True
                if val_loss < best_loss:
                    best_loss = val_loss
                    if save_model:
                        io.save(self, save_dir, self.logger)

                msg = (
                    "Epoch: {:03d} | Validation Loss: {:.5f} |"
                    " Historical Best: {:.5f}"
                )
                self.logger.info(msg.format(epoch, val_loss, best_loss))
                if self.tb_logger:
                    self.tb_logger.add_scalar(
                        "fusion/Validation_Loss", val_loss, epoch
                    )

            # Update the scheduler
            if hasattr(self, "scheduler_"):
                if self.scheduler_name == "ReduceLROnPlateau":
                    # The plateau metric is the validation loss when
                    # available, otherwise the last batch's training loss.
                    if val_loss is not None:
                        self.scheduler_.step(val_loss)
                    else:
                        self.scheduler_.step(loss.item())
                else:
                    self.scheduler_.step()

        if save_model and not validated:
            io.save(self, save_dir, self.logger)

    @torchensemble_model_doc(item="regressor_evaluate")
    def evaluate(self, test_loader):
        return super().evaluate(test_loader)

    @torchensemble_model_doc(item="predict")
    def predict(self, *x):
        return super().predict(*x)