Commit d4c8a60

[ML] Line search feature bag fraction for classification and regression model training (#1761)
Following on from #1733, we can get further speedups by line searching for the best feature bag fraction for data sets where we only need a fraction of the features per tree. For example, training time on Higgs 1M drops from 2585s to 1742s, and we actually get a small improvement in accuracy because our hyperparameter search region is better initialised. This makes three changes:

1. Adds a line search for the best initial feature bag fraction to use (sketched below).
2. Adds a small linear penalty, of at most 1% of the minimum loss, to encourage larger downsample factors and smaller feature bag fractions.
3. Better handles the case where we have many features and relatively few training examples.
1 parent 7a6be22 commit d4c8a60
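
To make points 1 and 2 concrete, here is a minimal standalone sketch, not the ml-cpp implementation: the feature bag fraction is swept on a log scale, each candidate is scored by a hypothetical train-and-evaluate step, and a linear penalty capped at 1% of the smallest observed loss is added so that cheaper, smaller fractions win ties. The function trainAndComputeTestLoss and all the constants are placeholders.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

// Hypothetical stand-in for training a small forest with a given feature bag
// fraction and returning the loss on a held-out test set.
double trainAndComputeTestLoss(double featureBagFraction) {
    // Toy loss surface with a minimum near a fraction of 0.5.
    double x = std::log2(featureBagFraction) + 1.0;
    return 1.0 + 0.5 * x * x;
}

// Line search the feature bag fraction on a log scale. A linear penalty, capped
// at maxPenaltyFraction of the smallest observed loss, is added so that ties are
// broken in favour of smaller (cheaper) fractions.
double lineSearchFeatureBagFraction(double minFraction,
                                    double maxFraction,
                                    std::size_t steps,
                                    double maxPenaltyFraction = 0.01) {
    double logMin = std::log(minFraction);
    double logMax = std::log(maxFraction);

    std::vector<double> fractions;
    std::vector<double> losses;
    for (std::size_t i = 0; i < steps; ++i) {
        double logFraction = logMin + (logMax - logMin) * static_cast<double>(i) /
                                          static_cast<double>(steps - 1);
        fractions.push_back(std::exp(logFraction));
        losses.push_back(trainAndComputeTestLoss(fractions.back()));
    }

    double minLoss = *std::min_element(losses.begin(), losses.end());

    // The penalty grows linearly from zero at minFraction to
    // maxPenaltyFraction * minLoss at maxFraction.
    std::size_t best = 0;
    double bestAdjustedLoss = std::numeric_limits<double>::max();
    for (std::size_t i = 0; i < fractions.size(); ++i) {
        double t = (std::log(fractions[i]) - logMin) / (logMax - logMin);
        double adjustedLoss = losses[i] + maxPenaltyFraction * minLoss * t;
        if (adjustedLoss < bestAdjustedLoss) {
            bestAdjustedLoss = adjustedLoss;
            best = i;
        }
    }
    return fractions[best];
}

int main() {
    double fraction = lineSearchFeatureBagFraction(0.05, 1.0, 8);
    std::cout << "chosen feature bag fraction = " << fraction << "\n";
}

In the actual change the chosen value is used to centre the feature bag fraction interval searched during the main hyperparameter optimisation loop (see initializeUnsetFeatureBagFraction in the CBoostedTreeFactory.h diff below).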

File tree

6 files changed (+265, -131 lines)

docs/CHANGELOG.asciidoc

Lines changed: 7 additions & 0 deletions
@@ -28,6 +28,13 @@

 //=== Regressions

+== {es} version 7.13.0
+
+=== Enhancements
+
+* Speed up training of regression and classification model training for data sets
+  with many features. (See {ml-pull}1746[#1746].)
+
 == {es} version 7.12.0

 === Enhancements

include/maths/CBoostedTreeFactory.h

Lines changed: 12 additions & 6 deletions
@@ -119,7 +119,7 @@ class MATHS_EXPORT CBoostedTreeFactory final {
     //! Set the number of training examples we need per feature we'll include.
     CBoostedTreeFactory& numberTopShapValues(std::size_t numberTopShapValues);
     //! Set the flag to enable or disable early stopping.
-    CBoostedTreeFactory& earlyStoppingEnabled(bool earlyStoppingEnabled);
+    CBoostedTreeFactory& earlyStoppingEnabled(bool enable);

     //! Set pointer to the analysis instrumentation.
     CBoostedTreeFactory&
@@ -147,7 +147,8 @@ class MATHS_EXPORT CBoostedTreeFactory final {
     using TOptionalVector = boost::optional<TVector>;
     using TPackedBitVectorVec = std::vector<core::CPackedBitVector>;
     using TBoostedTreeImplUPtr = std::unique_ptr<CBoostedTreeImpl>;
-    using TApplyRegularizer = std::function<bool(CBoostedTreeImpl&, double)>;
+    using TApplyParameter = std::function<bool(CBoostedTreeImpl&, double)>;
+    using TAdjustTestLoss = std::function<double(double, double, double)>;

 private:
     CBoostedTreeFactory(std::size_t numberThreads, TLossFunctionUPtr loss);
@@ -190,6 +191,9 @@ class MATHS_EXPORT CBoostedTreeFactory final {
     //! search bounding box.
     void initializeUnsetRegularizationHyperparameters(core::CDataFrame& frame);

+    //! Estimate a good central value for the feature bag fraction search interval.
+    void initializeUnsetFeatureBagFraction(core::CDataFrame& frame);
+
     //! Estimates a good central value for the downsample factor search interval.
     void initializeUnsetDownsampleFactor(core::CDataFrame& frame);

@@ -208,11 +212,12 @@ class MATHS_EXPORT CBoostedTreeFactory final {
     //! \return The interval to search during the main hyperparameter optimisation
     //! loop or null if this couldn't be found.
     TOptionalVector testLossLineSearch(core::CDataFrame& frame,
-                                       const TApplyRegularizer& applyRegularizerStep,
+                                       const TApplyParameter& applyParameterStep,
                                        double intervalLeftEnd,
                                        double intervalRightEnd,
                                        double returnedIntervalLeftEndOffset,
-                                       double returnedIntervalRightEndOffset) const;
+                                       double returnedIntervalRightEndOffset,
+                                       const TAdjustTestLoss& adjustTestLoss = noopAdjustTestLoss) const;

     //! Initialize the state for hyperparameter optimisation.
     void initializeHyperparameterOptimisation() const;
@@ -264,8 +269,8 @@ class MATHS_EXPORT CBoostedTreeFactory final {
     //! Stubs out persistence.
     static void noopRecordTrainingState(CBoostedTree::TPersistFunc);

-    //! Stop hyperparameter optimization early if the process is not promising.
-    void stopHyperparameterOptimizationEarly(bool stopEarly);
+    //! Stubs out test loss adjustment.
+    static double noopAdjustTestLoss(double, double, double testLoss);

 private:
     TOptionalDouble m_MinimumFrequencyToOneHotEncode;
@@ -280,6 +285,7 @@ class MATHS_EXPORT CBoostedTreeFactory final {
     std::size_t m_NumberThreads;
     TBoostedTreeImplUPtr m_TreeImpl;
     TVector m_LogDownsampleFactorSearchInterval;
+    TVector m_LogFeatureBagFractionInterval;
     TVector m_LogDepthPenaltyMultiplierSearchInterval;
     TVector m_LogTreeSizePenaltyMultiplierSearchInterval;
     TVector m_LogLeafWeightPenaltyMultiplierSearchInterval;
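
The header changes above replace the regularizer-specific line search hook with two generic callbacks, TApplyParameter and TAdjustTestLoss (with a no-op default), so the same testLossLineSearch routine can drive the downsample factor, the feature bag fraction or a regularisation term, and can optionally bias the measured loss. Below is a minimal sketch of that pattern, with a hypothetical Model type and toy loss standing in for CBoostedTreeImpl and real training; the meaning of the first two TAdjustTestLoss arguments is an assumption, since the diff only names the last one (testLoss).

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <limits>

// Hypothetical stand-in for CBoostedTreeImpl: just two tunable parameters.
struct Model {
    double downsampleFactor = 0.5;
    double featureBagFraction = 1.0;
};

// Mirror the new typedefs: apply one candidate value of some parameter, and
// optionally adjust the measured test loss (assumed: minimum loss, parameter, loss).
using TApplyParameter = std::function<bool(Model&, double)>;
using TAdjustTestLoss = std::function<double(double, double, double)>;

double noopAdjustTestLoss(double, double, double testLoss) {
    return testLoss;
}

// Toy test loss whose best feature bag fraction is around 0.4.
double testLoss(const Model& model) {
    double x = model.featureBagFraction - 0.4;
    return 1.0 + 4.0 * x * x;
}

// One generic line search reused for every parameter: step across
// [leftEnd, rightEnd], measure the (optionally adjusted) loss and keep the best.
double lineSearch(Model& model,
                  const TApplyParameter& applyParameter,
                  double leftEnd,
                  double rightEnd,
                  std::size_t steps,
                  const TAdjustTestLoss& adjustTestLoss = noopAdjustTestLoss) {
    double bestParameter = leftEnd;
    double bestAdjustedLoss = std::numeric_limits<double>::max();
    double minLoss = std::numeric_limits<double>::max();
    for (std::size_t i = 0; i < steps; ++i) {
        double parameter = leftEnd + (rightEnd - leftEnd) * static_cast<double>(i) /
                                         static_cast<double>(steps - 1);
        if (applyParameter(model, parameter) == false) {
            continue; // Candidate couldn't be applied, e.g. out of range.
        }
        double loss = testLoss(model);
        minLoss = std::min(minLoss, loss);
        double adjustedLoss = adjustTestLoss(minLoss, parameter, loss);
        if (adjustedLoss < bestAdjustedLoss) {
            bestAdjustedLoss = adjustedLoss;
            bestParameter = parameter;
        }
    }
    return bestParameter;
}

int main() {
    Model model;
    auto applyFraction = [](Model& m, double value) {
        m.featureBagFraction = value;
        return true;
    };
    // Default no-op adjustment...
    double best = lineSearch(model, applyFraction, 0.2, 1.0, 9);
    // ...or a small penalty, here favouring smaller fractions.
    auto penalise = [](double minLoss, double fraction, double loss) {
        return loss + 0.01 * minLoss * fraction;
    };
    double bestPenalised = lineSearch(model, applyFraction, 0.2, 1.0, 9, penalise);
    std::cout << best << " " << bestPenalised << "\n";
}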

include/maths/CBoostedTreeImpl.h

Lines changed: 3 additions & 2 deletions
@@ -192,8 +192,9 @@ class MATHS_EXPORT CBoostedTreeImpl final {
         E_TreeSizePenaltyMultiplierInitialized = 3,
         E_LeafWeightPenaltyMultiplierInitialized = 4,
         E_DownsampleFactorInitialized = 5,
-        E_EtaInitialized = 6,
-        E_FullyInitialized = 7
+        E_FeatureBagFractionInitialized = 6,
+        E_EtaInitialized = 7,
+        E_FullyInitialized = 8
     };

 private:
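
The enum change inserts E_FeatureBagFractionInitialized before E_EtaInitialized and renumbers the later stages. A plausible reading, sketched below with hypothetical resume logic rather than the real CBoostedTreeImpl code, is that the ordered values act as a checkpoint: initialization restored part-way through skips every stage it has already completed, and the new feature bag fraction search slots in between the downsample factor and eta searches.

#include <iostream>

// A cut-down copy of the initialization stage enum from the diff above
// (earlier stages, values 0-2, elided).
enum EInitializationStage {
    E_TreeSizePenaltyMultiplierInitialized = 3,
    E_LeafWeightPenaltyMultiplierInitialized = 4,
    E_DownsampleFactorInitialized = 5,
    E_FeatureBagFractionInitialized = 6, // Newly inserted stage.
    E_EtaInitialized = 7,
    E_FullyInitialized = 8
};

// Hypothetical resume-style initialization: each step runs only if its stage
// hasn't been reached yet, so restored state picks up where it left off.
void initializeHyperparameters(EInitializationStage& stage) {
    if (stage < E_FeatureBagFractionInitialized) {
        std::cout << "line search the feature bag fraction\n";
        stage = E_FeatureBagFractionInitialized;
    }
    if (stage < E_EtaInitialized) {
        std::cout << "line search eta\n";
        stage = E_EtaInitialized;
    }
    stage = E_FullyInitialized;
}

int main() {
    EInitializationStage stage = E_DownsampleFactorInitialized;
    initializeHyperparameters(stage);
    std::cout << "final stage = " << static_cast<int>(stage) << "\n";
}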
