From 7c43b8ca5c41503ed2da89534cbdacbd63c739bd Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Fri, 19 Feb 2021 16:10:20 +0000 Subject: [PATCH 1/3] Scale regularisers for final train --- include/maths/CBoostedTreeImpl.h | 3 +++ lib/maths/CBoostedTreeFactory.cc | 29 +++++++++----------------- lib/maths/CBoostedTreeImpl.cc | 17 +++++++++++++++ lib/maths/unittest/CBoostedTreeTest.cc | 2 +- 4 files changed, 31 insertions(+), 20 deletions(-) diff --git a/include/maths/CBoostedTreeImpl.h b/include/maths/CBoostedTreeImpl.h index f201cd8f3c..b3ee21aa76 100644 --- a/include/maths/CBoostedTreeImpl.h +++ b/include/maths/CBoostedTreeImpl.h @@ -311,6 +311,9 @@ class MATHS_EXPORT CBoostedTreeImpl final { //! Set the hyperparamaters from the best recorded. void restoreBestHyperparameters(); + //! Scale the regulariser multipliers by \p scale. + void scaleRegularizers(double scale); + //! Check invariants which are assumed to hold after restoring. void checkRestoredInvariants() const; diff --git a/lib/maths/CBoostedTreeFactory.cc b/lib/maths/CBoostedTreeFactory.cc index c2477c3f4a..d06107467f 100644 --- a/lib/maths/CBoostedTreeFactory.cc +++ b/lib/maths/CBoostedTreeFactory.cc @@ -769,32 +769,23 @@ void CBoostedTreeFactory::initializeUnsetDownsampleFactor(core::CDataFrame& fram (logMinDownsampleFactor + logMaxDownsampleFactor) / 2.0}; LOG_TRACE(<< "mean log downsample factor = " << meanLogDownSampleFactor); - double previousDownsampleFactor{m_TreeImpl->m_DownsampleFactor}; - double previousDepthPenaltyMultiplier{ + double initialDownsampleFactor{m_TreeImpl->m_DownsampleFactor}; + double initialDepthPenaltyMultiplier{ m_TreeImpl->m_Regularization.depthPenaltyMultiplier()}; - double previousTreeSizePenaltyMultiplier{ + double initialTreeSizePenaltyMultiplier{ m_TreeImpl->m_Regularization.treeSizePenaltyMultiplier()}; - double previousLeafWeightPenaltyMultiplier{ + double initialLeafWeightPenaltyMultiplier{ m_TreeImpl->m_Regularization.leafWeightPenaltyMultiplier()}; // We need to scale the regularisation terms to account for the difference // in the downsample factor compared to the value used in the line search. auto scaleRegularizers = [&](CBoostedTreeImpl& tree, double downsampleFactor) { - double scale{previousDownsampleFactor / downsampleFactor}; - if (tree.m_RegularizationOverride.depthPenaltyMultiplier() == boost::none) { - tree.m_Regularization.depthPenaltyMultiplier( - scale * previousDepthPenaltyMultiplier); - } - if (tree.m_RegularizationOverride.treeSizePenaltyMultiplier() == - boost::none) { - tree.m_Regularization.treeSizePenaltyMultiplier( - scale * previousTreeSizePenaltyMultiplier); - } - if (tree.m_RegularizationOverride.leafWeightPenaltyMultiplier() == - boost::none) { - tree.m_Regularization.leafWeightPenaltyMultiplier( - scale * previousLeafWeightPenaltyMultiplier); - } + double scale{initialDownsampleFactor / downsampleFactor}; + tree.m_Regularization.depthPenaltyMultiplier(initialDepthPenaltyMultiplier); + tree.m_Regularization.treeSizePenaltyMultiplier(initialTreeSizePenaltyMultiplier); + tree.m_Regularization.leafWeightPenaltyMultiplier( + initialLeafWeightPenaltyMultiplier); + tree.scaleRegularizers(scale); return scale; }; diff --git a/lib/maths/CBoostedTreeImpl.cc b/lib/maths/CBoostedTreeImpl.cc index 6303701dfb..8651caa9e8 100644 --- a/lib/maths/CBoostedTreeImpl.cc +++ b/lib/maths/CBoostedTreeImpl.cc @@ -261,6 +261,8 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame, LOG_TRACE(<< "Test loss = " << m_BestForestTestLoss); this->restoreBestHyperparameters(); + this->scaleRegularizers(allTrainingRowsMask.manhattan() / + m_TrainingRowMasks[0].manhattan()); this->startProgressMonitoringFinalTrain(); std::tie(m_BestForest, std::ignore, std::ignore) = this->trainForest( frame, allTrainingRowsMask, allTrainingRowsMask, m_TrainingProgress); @@ -1452,6 +1454,21 @@ void CBoostedTreeImpl::restoreBestHyperparameters() { << ", feature bag fraction* = " << m_FeatureBagFraction); } +void CBoostedTreeImpl::scaleRegularizers(double scale) { + if (m_RegularizationOverride.depthPenaltyMultiplier() == boost::none) { + m_Regularization.depthPenaltyMultiplier( + scale * m_Regularization.depthPenaltyMultiplier()); + } + if (m_RegularizationOverride.treeSizePenaltyMultiplier() == boost::none) { + m_Regularization.treeSizePenaltyMultiplier( + scale * m_Regularization.treeSizePenaltyMultiplier()); + } + if (m_RegularizationOverride.leafWeightPenaltyMultiplier() == boost::none) { + m_Regularization.leafWeightPenaltyMultiplier( + scale * m_Regularization.leafWeightPenaltyMultiplier()); + } +} + std::size_t CBoostedTreeImpl::numberHyperparametersToTune() const { return m_RegularizationOverride.countNotSet() + (m_DownsampleFactorOverride != boost::none ? 0 : 1) + diff --git a/lib/maths/unittest/CBoostedTreeTest.cc b/lib/maths/unittest/CBoostedTreeTest.cc index 9e13eb52b3..b61e606306 100644 --- a/lib/maths/unittest/CBoostedTreeTest.cc +++ b/lib/maths/unittest/CBoostedTreeTest.cc @@ -1389,7 +1389,7 @@ BOOST_AUTO_TEST_CASE(testImbalancedClasses) { LOG_DEBUG(<< "recalls = " << core::CContainerPrinter::print(recalls)); BOOST_TEST_REQUIRE(std::fabs(precisions[0] - precisions[1]) < 0.1); - BOOST_TEST_REQUIRE(std::fabs(recalls[0] - recalls[1]) < 0.11); + BOOST_TEST_REQUIRE(std::fabs(recalls[0] - recalls[1]) < 0.13); } BOOST_AUTO_TEST_CASE(testClassificationWeightsOverride) { From 2c23e249e34aef74ff23e2ac7b1aabc304ba1b61 Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Mon, 22 Feb 2021 10:44:01 +0000 Subject: [PATCH 2/3] Test thresholds --- lib/maths/unittest/CBoostedTreeTest.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/maths/unittest/CBoostedTreeTest.cc b/lib/maths/unittest/CBoostedTreeTest.cc index b61e606306..6dbc170fda 100644 --- a/lib/maths/unittest/CBoostedTreeTest.cc +++ b/lib/maths/unittest/CBoostedTreeTest.cc @@ -882,7 +882,7 @@ BOOST_AUTO_TEST_CASE(testCategoricalRegressors) { LOG_DEBUG(<< "bias = " << modelBias); LOG_DEBUG(<< " R^2 = " << modelRSquared); BOOST_REQUIRE_CLOSE_ABSOLUTE(0.0, modelBias, 0.16); - BOOST_TEST_REQUIRE(modelRSquared > 0.97); + BOOST_TEST_REQUIRE(modelRSquared > 0.95); } BOOST_AUTO_TEST_CASE(testFeatureBags) { @@ -1301,13 +1301,13 @@ BOOST_AUTO_TEST_CASE(testBinomialLogisticRegression) { LOG_DEBUG(<< "log relative error = " << maths::CBasicStatistics::mean(logRelativeError)); - BOOST_TEST_REQUIRE(maths::CBasicStatistics::mean(logRelativeError) < 0.681); + BOOST_TEST_REQUIRE(maths::CBasicStatistics::mean(logRelativeError) < 0.69); meanLogRelativeError.add(maths::CBasicStatistics::mean(logRelativeError)); } LOG_DEBUG(<< "mean log relative error = " << maths::CBasicStatistics::mean(meanLogRelativeError)); - BOOST_TEST_REQUIRE(maths::CBasicStatistics::mean(meanLogRelativeError) < 0.51); + BOOST_TEST_REQUIRE(maths::CBasicStatistics::mean(meanLogRelativeError) < 0.52); } BOOST_AUTO_TEST_CASE(testImbalancedClasses) { From 691b36eb24930cc119272b0e17cba903e4c5c2f9 Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Mon, 22 Feb 2021 10:47:55 +0000 Subject: [PATCH 3/3] Docs --- docs/CHANGELOG.asciidoc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index de88627959..d4f8070571 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -45,6 +45,10 @@ * Speed up training of regression and classification model training for data sets with many features. (See {ml-pull}1746[#1746].) +* Avoid overfitting in final training by scaling regularizers to account for the + difference in the number of training examples. This results in a better match + between train and test error for classification and regression and often slightly + improved test errors. (See {ml-pull}1755[#1755].) == {es} version 7.12.0