Skip to content

Commit e3fb9fe

Browse files
authored
fix: estimator overwrite problem with validation data loader (#143)
* update code * Update index.rst
1 parent a89cfb3 commit e3fb9fe

File tree

5 files changed

+41
-32
lines changed

5 files changed

+41
-32
lines changed

Diff for: CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Changelog
1818
Ver 0.1.*
1919
---------
2020

21+
* |Fix| Fix the issue that model will be overwritten when using the validation data loder | `@xuyxu <https://github.com/xuyxu>`__
2122
* |Feature| Add internal :meth:`unsqueeze` operation in :meth:`forward` of all classifiers | `@xuyxu <https://github.com/xuyxu>`__
2223
* |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifer`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg <https://github.com/LukasGardberg>`__
2324
* |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe <https://github.com/SunHaozhe>`__

Diff for: docs/index.rst

+4-2
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@ Content
6464
-------
6565

6666
.. toctree::
67-
:maxdepth: 1
6867
:caption: For Users
68+
:maxdepth: 2
69+
:hidden:
6970

7071
Quick Start <quick_start>
7172
Introduction <introduction>
@@ -74,8 +75,9 @@ Content
7475
API Reference <parameters>
7576

7677
.. toctree::
77-
:maxdepth: 1
7878
:caption: For Developers
79+
:maxdepth: 2
80+
:hidden:
7981

8082
Changelog <changelog>
8183
Roadmap <roadmap>

Diff for: torchensemble/adversarial_training.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,12 @@ def _forward(estimators, *x):
384384
acc,
385385
epoch,
386386
)
387+
# No validation
388+
else:
389+
self.estimators_ = nn.ModuleList()
390+
self.estimators_.extend(estimators)
391+
if save_model:
392+
io.save(self, save_dir, self.logger)
387393

388394
# Update the scheduler
389395
with warnings.catch_warnings():
@@ -402,11 +408,6 @@ def _forward(estimators, *x):
402408
else:
403409
scheduler_.step()
404410

405-
self.estimators_ = nn.ModuleList()
406-
self.estimators_.extend(estimators)
407-
if save_model and not test_loader:
408-
io.save(self, save_dir, self.logger)
409-
410411
@torchensemble_model_doc(item="classifier_evaluate")
411412
def evaluate(self, test_loader, return_loss=False):
412413
return super().evaluate(test_loader, return_loss)
@@ -580,6 +581,12 @@ def _forward(estimators, *x):
580581
val_loss,
581582
epoch,
582583
)
584+
# No validation
585+
else:
586+
self.estimators_ = nn.ModuleList()
587+
self.estimators_.extend(estimators)
588+
if save_model:
589+
io.save(self, save_dir, self.logger)
583590

584591
# Update the scheduler
585592
with warnings.catch_warnings():
@@ -595,11 +602,6 @@ def _forward(estimators, *x):
595602
else:
596603
scheduler_.step()
597604

598-
self.estimators_ = nn.ModuleList()
599-
self.estimators_.extend(estimators)
600-
if save_model and not test_loader:
601-
io.save(self, save_dir, self.logger)
602-
603605
@torchensemble_model_doc(item="regressor_evaluate")
604606
def evaluate(self, test_loader):
605607
return super().evaluate(test_loader)

Diff for: torchensemble/bagging.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,12 @@ def _forward(estimators, *x):
253253
self.tb_logger.add_scalar(
254254
"bagging/Validation_Acc", acc, epoch
255255
)
256+
# No validation
257+
else:
258+
self.estimators_ = nn.ModuleList()
259+
self.estimators_.extend(estimators)
260+
if save_model:
261+
io.save(self, save_dir, self.logger)
256262

257263
# Update the scheduler
258264
with warnings.catch_warnings():
@@ -271,11 +277,6 @@ def _forward(estimators, *x):
271277
else:
272278
scheduler_.step()
273279

274-
self.estimators_ = nn.ModuleList()
275-
self.estimators_.extend(estimators)
276-
if save_model and not test_loader:
277-
io.save(self, save_dir, self.logger)
278-
279280
@torchensemble_model_doc(item="classifier_evaluate")
280281
def evaluate(self, test_loader, return_loss=False):
281282
return super().evaluate(test_loader, return_loss)
@@ -449,6 +450,12 @@ def _forward(estimators, *x):
449450
self.tb_logger.add_scalar(
450451
"bagging/Validation_Loss", val_loss, epoch
451452
)
453+
# No validation
454+
else:
455+
self.estimators_ = nn.ModuleList()
456+
self.estimators_.extend(estimators)
457+
if save_model:
458+
io.save(self, save_dir, self.logger)
452459

453460
# Update the scheduler
454461
with warnings.catch_warnings():
@@ -464,11 +471,6 @@ def _forward(estimators, *x):
464471
else:
465472
scheduler_.step()
466473

467-
self.estimators_ = nn.ModuleList()
468-
self.estimators_.extend(estimators)
469-
if save_model and not test_loader:
470-
io.save(self, save_dir, self.logger)
471-
472474
@torchensemble_model_doc(item="regressor_evaluate")
473475
def evaluate(self, test_loader):
474476
return super().evaluate(test_loader)

Diff for: torchensemble/voting.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,12 @@ def _forward(estimators, *x):
276276
self.tb_logger.add_scalar(
277277
"voting/Validation_Acc", acc, epoch
278278
)
279+
# No validation
280+
else:
281+
self.estimators_ = nn.ModuleList()
282+
self.estimators_.extend(estimators)
283+
if save_model:
284+
io.save(self, save_dir, self.logger)
279285

280286
# Update the scheduler
281287
with warnings.catch_warnings():
@@ -294,11 +300,6 @@ def _forward(estimators, *x):
294300
else:
295301
scheduler_.step()
296302

297-
self.estimators_ = nn.ModuleList()
298-
self.estimators_.extend(estimators)
299-
if save_model and not test_loader:
300-
io.save(self, save_dir, self.logger)
301-
302303
@torchensemble_model_doc(item="classifier_evaluate")
303304
def evaluate(self, test_loader, return_loss=False):
304305
return super().evaluate(test_loader, return_loss)
@@ -532,6 +533,12 @@ def _forward(estimators, *x):
532533
self.tb_logger.add_scalar(
533534
"voting/Validation_Loss", val_loss, epoch
534535
)
536+
# No validation
537+
else:
538+
self.estimators_ = nn.ModuleList()
539+
self.estimators_.extend(estimators)
540+
if save_model:
541+
io.save(self, save_dir, self.logger)
535542

536543
# Update the scheduler
537544
with warnings.catch_warnings():
@@ -547,11 +554,6 @@ def _forward(estimators, *x):
547554
else:
548555
scheduler_.step()
549556

550-
self.estimators_ = nn.ModuleList()
551-
self.estimators_.extend(estimators)
552-
if save_model and not test_loader:
553-
io.save(self, save_dir, self.logger)
554-
555557
@torchensemble_model_doc(item="regressor_evaluate")
556558
def evaluate(self, test_loader):
557559
return super().evaluate(test_loader)

0 commit comments

Comments
 (0)