Skip to content

Commit 4366765

Browse files
fix: bug in doubling estimators_ if model loaded from save_dir + feature for having callback after each epoch (#166)
* fix: dont instantiate if estimators loaded from save_dir * feat: add on_epoch_end_cb to call at end of each epoch --------- Co-authored-by: [email protected] <[email protected]>
1 parent 6726a99 commit 4366765

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

torchensemble/soft_gradient_boosting.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,14 @@ def fit(
209209
test_loader=None,
210210
save_model=True,
211211
save_dir=None,
212+
on_epoch_end_cb=None
212213
):
213214

214215
# Instantiate base estimators and set attributes
215-
for _ in range(self.n_estimators):
216-
self.estimators_.append(self._make_estimator())
216+
# dont instantiate if estimators loaded from save_dir
217+
if len(self.estimators_) != self.n_estimators:
218+
for _ in range(self.n_estimators):
219+
self.estimators_.append(self._make_estimator())
217220
self._validate_parameters(epochs, log_interval)
218221
self.n_outputs = self._decide_n_outputs(train_loader)
219222

@@ -295,6 +298,9 @@ def fit(
295298
else:
296299
scheduler.step()
297300

301+
# Call on epoch end
302+
if on_epoch_end_cb:
303+
on_epoch_end_cb(epoch)
298304
if save_model and not test_loader:
299305
io.save(self, save_dir, self.logger)
300306

@@ -390,6 +396,7 @@ def fit(
390396
test_loader=None,
391397
save_model=True,
392398
save_dir=None,
399+
on_epoch_end_cb=None
393400
):
394401
super().fit(
395402
train_loader=train_loader,
@@ -399,6 +406,7 @@ def fit(
399406
test_loader=test_loader,
400407
save_model=save_model,
401408
save_dir=save_dir,
409+
on_epoch_end_cb=on_epoch_end_cb,
402410
)
403411

404412
@torchensemble_model_doc(

0 commit comments

Comments
 (0)