Commit 0614b01

[ML] Scale regularisers for final train (#1755)
As we move towards tuning hyperparameters on a small fraction of the data set and running the final training on more of it, we will suffer from overfitting unless we address the bias this introduces into the estimated regularisers. Interestingly, we already see a mismatch between train and test errors on larger data sets, where we use only two folds. I tested this correction, which is the same one we apply when we downsample, on a variety of data sets, and it gave a lower mismatch between train and test errors.
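A minimal numeric sketch of the correction (invented numbers and names, not the library code): if the penalised objective adds a fixed penalty term λ·P to a loss summed over the training examples, a multiplier λ chosen while tuning on n_tune examples is relatively too weak once the final train sees n_final > n_tune examples, so scaling λ by n_final / n_tune restores the penalty-to-loss balance chosen during tuning.

```cpp
// Hypothetical illustration of the regulariser scaling; not the library code.
#include <iostream>

int main() {
    double nTune{10000.0};   // examples per fold seen during hyperparameter tuning
    double nFinal{50000.0};  // examples seen by the final train
    double lambdaTuned{0.2}; // a penalty multiplier chosen by the tuning search

    // The data-dependent loss grows with the example count while the penalty
    // does not, so rescale the multiplier by the ratio of example counts.
    double scale{nFinal / nTune};            // 5
    double lambdaFinal{scale * lambdaTuned}; // 1

    std::cout << "scale = " << scale << ", scaled multiplier = " << lambdaFinal << '\n';
    return 0;
}
```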
Parent: abbbda9

5 files changed: 38 additions and 23 deletions

docs/CHANGELOG.asciidoc

Lines changed: 4 additions & 0 deletions
@@ -45,6 +45,10 @@

 * Speed up training of regression and classification model training for data sets
   with many features. (See {ml-pull}1746[#1746].)
+* Avoid overfitting in final training by scaling regularizers to account for the
+  difference in the number of training examples. This results in a better match
+  between train and test error for classification and regression and often slightly
+  improved test errors. (See {ml-pull}1755[#1755].)

 == {es} version 7.12.0

include/maths/CBoostedTreeImpl.h

Lines changed: 3 additions & 0 deletions
@@ -311,6 +311,9 @@ class MATHS_EXPORT CBoostedTreeImpl final {
     //! Set the hyperparameters from the best recorded.
     void restoreBestHyperparameters();

+    //! Scale the regulariser multipliers by \p scale.
+    void scaleRegularizers(double scale);
+
     //! Check invariants which are assumed to hold after restoring.
     void checkRestoredInvariants() const;

lib/maths/CBoostedTreeFactory.cc

Lines changed: 10 additions & 19 deletions
@@ -769,32 +769,23 @@ void CBoostedTreeFactory::initializeUnsetDownsampleFactor(core::CDataFrame& fram
         (logMinDownsampleFactor + logMaxDownsampleFactor) / 2.0};
     LOG_TRACE(<< "mean log downsample factor = " << meanLogDownSampleFactor);

-    double previousDownsampleFactor{m_TreeImpl->m_DownsampleFactor};
-    double previousDepthPenaltyMultiplier{
+    double initialDownsampleFactor{m_TreeImpl->m_DownsampleFactor};
+    double initialDepthPenaltyMultiplier{
         m_TreeImpl->m_Regularization.depthPenaltyMultiplier()};
-    double previousTreeSizePenaltyMultiplier{
+    double initialTreeSizePenaltyMultiplier{
         m_TreeImpl->m_Regularization.treeSizePenaltyMultiplier()};
-    double previousLeafWeightPenaltyMultiplier{
+    double initialLeafWeightPenaltyMultiplier{
         m_TreeImpl->m_Regularization.leafWeightPenaltyMultiplier()};

     // We need to scale the regularisation terms to account for the difference
     // in the downsample factor compared to the value used in the line search.
     auto scaleRegularizers = [&](CBoostedTreeImpl& tree, double downsampleFactor) {
-        double scale{previousDownsampleFactor / downsampleFactor};
-        if (tree.m_RegularizationOverride.depthPenaltyMultiplier() == boost::none) {
-            tree.m_Regularization.depthPenaltyMultiplier(
-                scale * previousDepthPenaltyMultiplier);
-        }
-        if (tree.m_RegularizationOverride.treeSizePenaltyMultiplier() ==
-            boost::none) {
-            tree.m_Regularization.treeSizePenaltyMultiplier(
-                scale * previousTreeSizePenaltyMultiplier);
-        }
-        if (tree.m_RegularizationOverride.leafWeightPenaltyMultiplier() ==
-            boost::none) {
-            tree.m_Regularization.leafWeightPenaltyMultiplier(
-                scale * previousLeafWeightPenaltyMultiplier);
-        }
+        double scale{initialDownsampleFactor / downsampleFactor};
+        tree.m_Regularization.depthPenaltyMultiplier(initialDepthPenaltyMultiplier);
+        tree.m_Regularization.treeSizePenaltyMultiplier(initialTreeSizePenaltyMultiplier);
+        tree.m_Regularization.leafWeightPenaltyMultiplier(
+            initialLeafWeightPenaltyMultiplier);
+        tree.scaleRegularizers(scale);
         return scale;
     };
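For context, this hunk changes where the multiplication happens rather than what is computed: the lambda now resets the multipliers to the values captured before the line search and delegates the scaling to the new CBoostedTreeImpl::scaleRegularizers member, which the final train can reuse. A hypothetical standalone sketch of that reset-then-scale pattern (invented names, not the library's types):

```cpp
// Invented stand-in for the regularization state; not the library code.
#include <iostream>

struct SRegularization {
    double depth{0.5}, treeSize{0.3}, leafWeight{0.1};
    // Analogue of scaleRegularizers: one multiplicative scale for all terms.
    void scaleMultipliers(double s) {
        depth *= s;
        treeSize *= s;
        leafWeight *= s;
    }
};

int main() {
    SRegularization initial;          // values in force before the line search
    SRegularization current{initial}; // reset to the initial values...
    current.scaleMultipliers(2.0);    // ...then apply one multiplicative scale
    std::cout << current.depth << ' ' << current.treeSize << ' '
              << current.leafWeight << '\n'; // prints: 1 0.6 0.2
    return 0;
}
```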

lib/maths/CBoostedTreeImpl.cc

Lines changed: 17 additions & 0 deletions
@@ -261,6 +261,8 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
     LOG_TRACE(<< "Test loss = " << m_BestForestTestLoss);

     this->restoreBestHyperparameters();
+    this->scaleRegularizers(allTrainingRowsMask.manhattan() /
+                            m_TrainingRowMasks[0].manhattan());
     this->startProgressMonitoringFinalTrain();
     std::tie(m_BestForest, std::ignore, std::ignore) = this->trainForest(
         frame, allTrainingRowsMask, allTrainingRowsMask, m_TrainingProgress);
@@ -1452,6 +1454,21 @@ void CBoostedTreeImpl::restoreBestHyperparameters() {
              << ", feature bag fraction* = " << m_FeatureBagFraction);
 }

+void CBoostedTreeImpl::scaleRegularizers(double scale) {
+    if (m_RegularizationOverride.depthPenaltyMultiplier() == boost::none) {
+        m_Regularization.depthPenaltyMultiplier(
+            scale * m_Regularization.depthPenaltyMultiplier());
+    }
+    if (m_RegularizationOverride.treeSizePenaltyMultiplier() == boost::none) {
+        m_Regularization.treeSizePenaltyMultiplier(
+            scale * m_Regularization.treeSizePenaltyMultiplier());
+    }
+    if (m_RegularizationOverride.leafWeightPenaltyMultiplier() == boost::none) {
+        m_Regularization.leafWeightPenaltyMultiplier(
+            scale * m_Regularization.leafWeightPenaltyMultiplier());
+    }
+}
+
 std::size_t CBoostedTreeImpl::numberHyperparametersToTune() const {
     return m_RegularizationOverride.countNotSet() +
            (m_DownsampleFactorOverride != boost::none ? 0 : 1) +
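The new method rescales a multiplier only when the user has not overridden it, and the scale passed by train() is, in effect, the ratio of the total number of training rows to the rows in one cross-validation training fold. A self-contained sketch of that pattern, using std::optional in place of boost::optional and invented names:

```cpp
// Hypothetical analogue of the override-respecting scaling; not the library code.
#include <iostream>
#include <optional>

struct SPenalties {
    double depth{0.4}, treeSize{0.2}, leafWeight{0.1};
};

struct SOverrides {
    std::optional<double> depth, treeSize, leafWeight;
};

// Rescale a multiplier only if the corresponding user override is unset,
// so explicit user-supplied regularisation is left untouched.
void scalePenalties(SPenalties& penalties, const SOverrides& overrides, double scale) {
    if (!overrides.depth) { penalties.depth *= scale; }
    if (!overrides.treeSize) { penalties.treeSize *= scale; }
    if (!overrides.leafWeight) { penalties.leafWeight *= scale; }
}

int main() {
    SPenalties penalties;
    SOverrides overrides;
    overrides.treeSize = 0.2; // user pinned this multiplier; leave it alone

    // Final training sees all rows; tuning trained on one fold's rows.
    double allTrainingRows{60000.0};
    double foldTrainingRows{40000.0};
    scalePenalties(penalties, overrides, allTrainingRows / foldTrainingRows);

    std::cout << penalties.depth << ' '        // 0.6  (scaled by 1.5)
              << penalties.treeSize << ' '     // 0.2  (override respected)
              << penalties.leafWeight << '\n'; // 0.15 (scaled by 1.5)
    return 0;
}
```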

lib/maths/unittest/CBoostedTreeTest.cc

Lines changed: 4 additions & 4 deletions
@@ -882,7 +882,7 @@ BOOST_AUTO_TEST_CASE(testCategoricalRegressors) {
     LOG_DEBUG(<< "bias = " << modelBias);
     LOG_DEBUG(<< " R^2 = " << modelRSquared);
     BOOST_REQUIRE_CLOSE_ABSOLUTE(0.0, modelBias, 0.16);
-    BOOST_TEST_REQUIRE(modelRSquared > 0.97);
+    BOOST_TEST_REQUIRE(modelRSquared > 0.95);
 }

 BOOST_AUTO_TEST_CASE(testFeatureBags) {
@@ -1301,13 +1301,13 @@ BOOST_AUTO_TEST_CASE(testBinomialLogisticRegression) {
         LOG_DEBUG(<< "log relative error = "
                   << maths::CBasicStatistics::mean(logRelativeError));

-        BOOST_TEST_REQUIRE(maths::CBasicStatistics::mean(logRelativeError) < 0.681);
+        BOOST_TEST_REQUIRE(maths::CBasicStatistics::mean(logRelativeError) < 0.69);
         meanLogRelativeError.add(maths::CBasicStatistics::mean(logRelativeError));
     }

     LOG_DEBUG(<< "mean log relative error = "
               << maths::CBasicStatistics::mean(meanLogRelativeError));
-    BOOST_TEST_REQUIRE(maths::CBasicStatistics::mean(meanLogRelativeError) < 0.51);
+    BOOST_TEST_REQUIRE(maths::CBasicStatistics::mean(meanLogRelativeError) < 0.52);
 }

 BOOST_AUTO_TEST_CASE(testImbalancedClasses) {
@@ -1389,7 +1389,7 @@ BOOST_AUTO_TEST_CASE(testImbalancedClasses) {
     LOG_DEBUG(<< "recalls = " << core::CContainerPrinter::print(recalls));

     BOOST_TEST_REQUIRE(std::fabs(precisions[0] - precisions[1]) < 0.1);
-    BOOST_TEST_REQUIRE(std::fabs(recalls[0] - recalls[1]) < 0.11);
+    BOOST_TEST_REQUIRE(std::fabs(recalls[0] - recalls[1]) < 0.13);
 }

 BOOST_AUTO_TEST_CASE(testClassificationWeightsOverride) {
